mirror of
https://github.com/langgenius/dify.git
synced 2026-03-31 18:00:55 -04:00
test: add tests for api/services retention, enterprise, plugin (#32648)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
|
||||
@pytest.fixture
def tenant(flask_req_ctx):
    """Provision a throwaway Tenant row, yield its id, and purge related rows afterwards."""
    with session_factory.create_session() as session:
        created = Tenant(name="plugin_it_tenant")
        session.add(created)
        session.commit()
        created_id = created.id

    yield created_id

    # Teardown: remove dependent permission/strategy rows first, then the tenant row itself.
    cleanup_statements = [
        delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == created_id),
        delete(TenantPluginAutoUpgradeStrategy).where(TenantPluginAutoUpgradeStrategy.tenant_id == created_id),
        delete(Tenant).where(Tenant.id == created_id),
    ]
    with session_factory.create_session() as session:
        for stmt in cleanup_statements:
            session.execute(stmt)
        session.commit()
|
||||
|
||||
|
||||
class TestPluginPermissionLifecycle:
    """End-to-end create/read/update checks for TenantPluginPermission via the service layer."""

    def test_get_returns_none_for_new_tenant(self, tenant):
        # A freshly created tenant has no permission row yet.
        fetched = PluginPermissionService.get_permission(tenant)
        assert fetched is None

    def test_change_creates_row(self, tenant):
        install = TenantPluginPermission.InstallPermission
        debug = TenantPluginPermission.DebugPermission

        changed = PluginPermissionService.change_permission(
            tenant,
            install.ADMINS,
            debug.EVERYONE,
        )
        assert changed is True

        stored = PluginPermissionService.get_permission(tenant)
        assert stored is not None
        assert stored.install_permission == install.ADMINS
        assert stored.debug_permission == debug.EVERYONE

    def test_change_updates_existing_row(self, tenant):
        install = TenantPluginPermission.InstallPermission
        debug = TenantPluginPermission.DebugPermission

        # First call creates the row; the second must update it in place.
        PluginPermissionService.change_permission(
            tenant,
            install.ADMINS,
            debug.NOBODY,
        )
        PluginPermissionService.change_permission(
            tenant,
            install.EVERYONE,
            debug.ADMINS,
        )
        stored = PluginPermissionService.get_permission(tenant)
        assert stored is not None
        assert stored.install_permission == install.EVERYONE
        assert stored.debug_permission == debug.ADMINS

        # Exactly one row should exist for the tenant (an update, not a second insert).
        with session_factory.create_session() as session:
            row_count = (
                session.query(TenantPluginPermission)
                .where(TenantPluginPermission.tenant_id == tenant)
                .count()
            )
            assert row_count == 1
|
||||
|
||||
|
||||
class TestPluginAutoUpgradeLifecycle:
    """Integration tests for PluginAutoUpgradeService strategy CRUD and exclude_plugin rules."""

    def test_get_returns_none_for_new_tenant(self, tenant):
        """A tenant without a stored strategy yields None."""
        assert PluginAutoUpgradeService.get_strategy(tenant) is None

    def test_change_creates_row(self, tenant):
        """change_strategy inserts a row when none exists and reports success."""
        result = PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=3,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        assert result is True

        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        assert strategy.upgrade_time_of_day == 3

    def test_change_updates_existing_row(self, tenant):
        """A second change_strategy call overwrites the existing row's fields."""
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=12,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            exclude_plugins=[],
            include_plugins=["plugin-a"],
        )

        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        assert strategy.upgrade_time_of_day == 12
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
        assert strategy.include_plugins == ["plugin-a"]

    def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant):
        """exclude_plugin bootstraps an EXCLUDE-mode strategy when the tenant has none."""
        PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin")

        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
        assert "my-plugin" in strategy.exclude_plugins

    def test_exclude_plugin_appends_in_exclude_mode(self, tenant):
        """In EXCLUDE mode, exclude_plugin appends to the existing exclude list."""
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["existing"],
            include_plugins=[],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin")

        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert "existing" in strategy.exclude_plugins
        assert "new-plugin" in strategy.exclude_plugins

    def test_exclude_plugin_dedup_in_exclude_mode(self, tenant):
        """Excluding an already-excluded plugin must not create a duplicate entry."""
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["same-plugin"],
            include_plugins=[],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin")

        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.exclude_plugins.count("same-plugin") == 1

    def test_exclude_from_partial_mode_removes_from_include(self, tenant):
        """In PARTIAL mode, excluding a plugin removes it from include_plugins."""
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            exclude_plugins=[],
            include_plugins=["p1", "p2"],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "p1")

        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert "p1" not in strategy.include_plugins
        assert "p2" in strategy.include_plugins

    def test_exclude_from_all_mode_switches_to_exclude(self, tenant):
        """Excluding from ALL mode flips the strategy to EXCLUDE mode."""
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin")

        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
        assert "excluded-plugin" in strategy.exclude_plugins
|
||||
@@ -0,0 +1,348 @@
|
||||
import datetime
|
||||
import math
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import (
|
||||
App,
|
||||
Conversation,
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
MessageFeedback,
|
||||
)
|
||||
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
|
||||
from services.retention.conversation.messages_clean_service import MessagesCleanService
|
||||
|
||||
# Fixed "current time" reference so seeded timestamps are deterministic across runs.
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
# Relative ages used when seeding messages.
_OLD = _NOW - datetime.timedelta(days=60)
_VERY_OLD = _NOW - datetime.timedelta(days=90)
_RECENT = _NOW - datetime.timedelta(days=5)

# Window spanning all three seeded timestamps with an hour of slack on each side.
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)

# Batch sizing: the default covers everything in one pass; the pagination values
# force multiple batches (25 messages with batch size 8 -> at least 4 batches).
_DEFAULT_BATCH_SIZE = 100
_PAGINATION_MESSAGE_COUNT = 25
_PAGINATION_BATCH_SIZE = 8
|
||||
|
||||
|
||||
@pytest.fixture
def tenant_and_app(flask_req_ctx):
    """Creates a Tenant, App and Conversation for the test and cleans up after."""
    with session_factory.create_session() as session:
        owner = Tenant(name="retention_it_tenant")
        session.add(owner)
        session.flush()

        application = App(
            tenant_id=owner.id,
            name="Retention IT App",
            mode="chat",
            enable_site=True,
            enable_api=True,
        )
        session.add(application)
        session.flush()

        conversation = Conversation(
            app_id=application.id,
            mode="chat",
            name="test_conv",
            status="normal",
            from_source="console",
            _inputs={},
        )
        session.add(conversation)
        session.commit()

        ids = {
            "tenant_id": owner.id,
            "app_id": application.id,
            "conversation_id": conversation.id,
        }

    yield ids

    # Teardown: delete children before parents so references are gone first.
    with session_factory.create_session() as session:
        session.execute(delete(Conversation).where(Conversation.id == ids["conversation_id"]))
        session.execute(delete(App).where(App.id == ids["app_id"]))
        session.execute(delete(Tenant).where(Tenant.id == ids["tenant_id"]))
        session.commit()
|
||||
|
||||
|
||||
def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message:
    """Build an unsaved Message row with minimal valid field values at the given timestamp."""
    fields = dict(
        app_id=app_id,
        conversation_id=conversation_id,
        created_at=created_at,
        # Minimal content payload; the retention tests only care about timestamps.
        query="test",
        message=[{"text": "hello"}],
        answer="world",
        from_source="console",
        currency="USD",
        _inputs={},
        # Token/price accounting fields must be present but are irrelevant here.
        message_tokens=1,
        message_unit_price=0,
        answer_tokens=1,
        answer_unit_price=0,
    )
    return Message(**fields)
|
||||
|
||||
|
||||
class TestMessagesCleanServiceIntegration:
    """Integration tests for MessagesCleanService against real Message rows.

    Uses the BillingDisabledPolicy so every message in the time window is eligible.
    """

    @pytest.fixture
    def seed_messages(self, tenant_and_app):
        """Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.

        Yields a semantic mapping keyed by age label.
        """
        data = tenant_and_app
        app_id = data["app_id"]
        conv_id = data["conversation_id"]
        # Ordered tuple of (label, timestamp) for deterministic seeding
        timestamps = [
            ("very_old", _VERY_OLD),
            ("old", _OLD),
            ("recent", _RECENT),
        ]
        msg_ids: dict[str, str] = {}

        with session_factory.create_session() as session:
            for label, ts in timestamps:
                msg = _make_message(app_id, conv_id, ts)
                session.add(msg)
                # flush assigns the primary key so we can record it before commit
                session.flush()
                msg_ids[label] = msg.id
            session.commit()

        yield {"msg_ids": msg_ids, **data}

        # Teardown: remove any messages the test under run did not delete itself.
        with session_factory.create_session() as session:
            session.execute(
                delete(Message)
                .where(Message.id.in_(list(msg_ids.values())))
                .execution_options(synchronize_session=False)
            )
            session.commit()

    @pytest.fixture
    def paginated_seed_messages(self, tenant_and_app):
        """Seeds multiple messages separated by 1-second increments starting at _OLD."""
        data = tenant_and_app
        app_id = data["app_id"]
        conv_id = data["conversation_id"]
        msg_ids: list[str] = []

        with session_factory.create_session() as session:
            for i in range(_PAGINATION_MESSAGE_COUNT):
                ts = _OLD + datetime.timedelta(seconds=i)
                msg = _make_message(app_id, conv_id, ts)
                session.add(msg)
                session.flush()
                msg_ids.append(msg.id)
            session.commit()

        yield {"msg_ids": msg_ids, **data}

        with session_factory.create_session() as session:
            session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False))
            session.commit()

    @pytest.fixture
    def cascade_test_data(self, tenant_and_app):
        """Seeds one Message with an associated Feedback and Annotation."""
        data = tenant_and_app
        app_id = data["app_id"]
        conv_id = data["conversation_id"]

        with session_factory.create_session() as session:
            msg = _make_message(app_id, conv_id, _OLD)
            session.add(msg)
            session.flush()

            feedback = MessageFeedback(
                app_id=app_id,
                conversation_id=conv_id,
                message_id=msg.id,
                rating=FeedbackRating.LIKE,
                from_source=FeedbackFromSource.USER,
            )
            annotation = MessageAnnotation(
                app_id=app_id,
                conversation_id=conv_id,
                message_id=msg.id,
                question="q",
                content="a",
                account_id=str(uuid.uuid4()),
            )
            session.add_all([feedback, annotation])
            session.commit()

            msg_id = msg.id
            fb_id = feedback.id
            ann_id = annotation.id

        yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}

        # Teardown: dependent rows first, then the message.
        with session_factory.create_session() as session:
            session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
            session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
            session.execute(delete(Message).where(Message.id == msg_id))
            session.commit()

    def test_dry_run_does_not_delete(self, seed_messages):
        """Dry-run must count eligible rows without deleting any of them."""
        data = seed_messages
        msg_ids = data["msg_ids"]
        all_ids = list(msg_ids.values())

        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=_WINDOW_START,
            end_before=_WINDOW_END,
            batch_size=_DEFAULT_BATCH_SIZE,
            dry_run=True,
        )
        stats = svc.run()

        assert stats["filtered_messages"] == len(all_ids)
        assert stats["total_deleted"] == 0

        # All seeded rows must still be present.
        with session_factory.create_session() as session:
            remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
            assert remaining == len(all_ids)

    def test_billing_disabled_deletes_all_in_range(self, seed_messages):
        """All 3 seeded messages fall within the window and must be deleted."""
        data = seed_messages
        msg_ids = data["msg_ids"]
        all_ids = list(msg_ids.values())

        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=_WINDOW_START,
            end_before=_WINDOW_END,
            batch_size=_DEFAULT_BATCH_SIZE,
            dry_run=False,
        )
        stats = svc.run()

        assert stats["total_deleted"] == len(all_ids)

        with session_factory.create_session() as session:
            remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
            assert remaining == 0

    def test_start_from_filters_correctly(self, seed_messages):
        """Only the message at _OLD falls within the narrow ±1 h window."""
        data = seed_messages
        msg_ids = data["msg_ids"]

        start = _OLD - datetime.timedelta(hours=1)
        end = _OLD + datetime.timedelta(hours=1)

        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=start,
            end_before=end,
            batch_size=_DEFAULT_BATCH_SIZE,
        )
        stats = svc.run()

        assert stats["total_deleted"] == 1

        # The "old" message is gone; its neighbours outside the window survive.
        with session_factory.create_session() as session:
            all_ids = list(msg_ids.values())
            remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}

        assert msg_ids["old"] not in remaining_ids
        assert msg_ids["very_old"] in remaining_ids
        assert msg_ids["recent"] in remaining_ids

    def test_cursor_pagination_across_batches(self, paginated_seed_messages):
        """Messages must be deleted across multiple batches."""
        data = paginated_seed_messages
        msg_ids = data["msg_ids"]

        # _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s.
        pagination_window_start = _OLD - datetime.timedelta(seconds=1)
        pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)

        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=pagination_window_start,
            end_before=pagination_window_end,
            batch_size=_PAGINATION_BATCH_SIZE,
        )
        stats = svc.run()

        assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
        # At least ceil(count / batch_size) batches are required to cover all rows.
        expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
        assert stats["batches"] >= expected_batches

        with session_factory.create_session() as session:
            remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
            assert remaining == 0

    def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
        """A window entirely in the future must yield zero matches."""
        far_future = _NOW + datetime.timedelta(days=365)
        even_further = far_future + datetime.timedelta(days=1)

        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=far_future,
            end_before=even_further,
            batch_size=_DEFAULT_BATCH_SIZE,
        )
        stats = svc.run()

        # NOTE(review): other tests read stats["filtered_messages"]; confirm
        # "total_messages" is the intended key for the empty-window case.
        assert stats["total_messages"] == 0
        assert stats["total_deleted"] == 0

    def test_relation_cascade_deletes(self, cascade_test_data):
        """Deleting a Message must cascade to its Feedback and Annotation rows."""
        data = cascade_test_data
        msg_id = data["msg_id"]
        fb_id = data["fb_id"]
        ann_id = data["ann_id"]

        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=_OLD - datetime.timedelta(hours=1),
            end_before=_OLD + datetime.timedelta(hours=1),
            batch_size=_DEFAULT_BATCH_SIZE,
        )
        stats = svc.run()

        assert stats["total_deleted"] == 1

        with session_factory.create_session() as session:
            assert session.query(Message).where(Message.id == msg_id).count() == 0
            assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
            assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0

    def test_factory_from_time_range_validation(self):
        """from_time_range rejects a start_from that is not before end_before."""
        with pytest.raises(ValueError, match="start_from"):
            MessagesCleanService.from_time_range(
                policy=BillingDisabledPolicy(),
                start_from=_NOW,
                end_before=_OLD,
            )

    def test_factory_from_days_validation(self):
        """from_days rejects a negative retention period."""
        with pytest.raises(ValueError, match="days"):
            MessagesCleanService.from_days(
                policy=BillingDisabledPolicy(),
                days=-1,
            )

    def test_factory_batch_size_validation(self):
        """from_time_range rejects a non-positive batch_size."""
        with pytest.raises(ValueError, match="batch_size"):
            MessagesCleanService.from_time_range(
                policy=BillingDisabledPolicy(),
                start_from=_OLD,
                end_before=_NOW,
                batch_size=0,
            )
|
||||
@@ -0,0 +1,177 @@
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
import zipfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import (
|
||||
ArchiveSummary,
|
||||
WorkflowRunArchiver,
|
||||
)
|
||||
from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION
|
||||
|
||||
|
||||
class TestWorkflowRunArchiverInit:
    """Constructor validation rules and default attribute wiring of WorkflowRunArchiver."""

    def test_start_from_without_end_before_raises(self):
        # Half-open time ranges are rejected: supply both bounds or neither.
        with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
            WorkflowRunArchiver(start_from=datetime.datetime(2025, 1, 1))

    def test_end_before_without_start_from_raises(self):
        with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
            WorkflowRunArchiver(end_before=datetime.datetime(2025, 1, 1))

    def test_start_equals_end_raises(self):
        boundary = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
            WorkflowRunArchiver(start_from=boundary, end_before=boundary)

    def test_start_after_end_raises(self):
        later = datetime.datetime(2025, 6, 1)
        earlier = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
            WorkflowRunArchiver(start_from=later, end_before=earlier)

    def test_workers_zero_raises(self):
        with pytest.raises(ValueError, match="workers must be at least 1"):
            WorkflowRunArchiver(workers=0)

    def test_valid_init_defaults(self):
        archiver = WorkflowRunArchiver(days=30, batch_size=50)

        assert archiver.days == 30
        assert archiver.batch_size == 50
        # Destructive options default to off; no time range unless given.
        assert archiver.dry_run is False
        assert archiver.delete_after_archive is False
        assert archiver.start_from is None

    def test_valid_init_with_time_range(self):
        window_start = datetime.datetime(2025, 1, 1)
        window_end = datetime.datetime(2025, 6, 1)

        archiver = WorkflowRunArchiver(start_from=window_start, end_before=window_end, workers=2)

        assert archiver.start_from is not None
        assert archiver.end_before is not None
        assert archiver.workers == 2
|
||||
|
||||
|
||||
class TestBuildArchiveBundle:
    """Zip bundle assembly from manifest bytes plus per-table payloads."""

    def test_bundle_contains_manifest_and_all_tables(self):
        archiver = WorkflowRunArchiver(days=90)

        manifest_blob = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8")
        # One (possibly empty) payload per archived table.
        payloads = {table: b"" for table in archiver.ARCHIVED_TABLES}

        bundle = archiver._build_archive_bundle(manifest_blob, payloads)

        with zipfile.ZipFile(io.BytesIO(bundle), "r") as zf:
            entries = set(zf.namelist())
            assert "manifest.json" in entries
            for table in archiver.ARCHIVED_TABLES:
                assert f"{table}.jsonl" in entries, f"Missing {table}.jsonl in bundle"

    def test_bundle_missing_table_payload_raises(self):
        archiver = WorkflowRunArchiver(days=90)
        manifest_blob = b"{}"
        # Only the first table's payload is supplied; the rest are missing.
        partial_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"}

        with pytest.raises(ValueError, match="Missing archive payload"):
            archiver._build_archive_bundle(manifest_blob, partial_payloads)
|
||||
|
||||
|
||||
class TestGenerateManifest:
    """Shape of the manifest dict produced for one archived workflow run."""

    def test_manifest_structure(self):
        from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats

        archiver = WorkflowRunArchiver(days=90)

        # A stand-in run object carrying only the attributes the manifest reads.
        fake_run = MagicMock()
        fake_run.id = str(uuid.uuid4())
        fake_run.tenant_id = str(uuid.uuid4())
        fake_run.app_id = str(uuid.uuid4())
        fake_run.workflow_id = str(uuid.uuid4())
        fake_run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0)

        per_table_stats = [
            TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512),
            TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024),
        ]

        manifest = archiver._generate_manifest(fake_run, per_table_stats)

        assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION
        assert manifest["workflow_run_id"] == fake_run.id
        assert manifest["tenant_id"] == fake_run.tenant_id
        assert manifest["app_id"] == fake_run.app_id
        assert "tables" in manifest
        runs_entry = manifest["tables"]["workflow_runs"]
        assert runs_entry["row_count"] == 1
        assert runs_entry["checksum"] == "abc123"
        assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2
|
||||
|
||||
|
||||
class TestFilterPaidTenants:
    """_filter_paid_tenants behaviour with billing off, billing on, and API failure."""

    def test_all_tenants_paid_when_billing_disabled(self):
        archiver = WorkflowRunArchiver(days=90)
        candidates = {"t1", "t2", "t3"}

        with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
            cfg.BILLING_ENABLED = False
            kept = archiver._filter_paid_tenants(candidates)

        # With billing off every tenant counts as paid.
        assert kept == candidates

    def test_empty_tenants_returns_empty(self):
        archiver = WorkflowRunArchiver(days=90)

        with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
            cfg.BILLING_ENABLED = True
            kept = archiver._filter_paid_tenants(set())

        assert kept == set()

    def test_only_paid_plans_returned(self):
        archiver = WorkflowRunArchiver(days=90)

        plan_by_tenant = {
            "t1": {"plan": "professional"},
            "t2": {"plan": "sandbox"},
            "t3": {"plan": "team"},
        }

        with (
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
        ):
            cfg.BILLING_ENABLED = True
            billing.get_plan_bulk_with_cache.return_value = plan_by_tenant
            kept = archiver._filter_paid_tenants({"t1", "t2", "t3"})

        # Only the "sandbox"-plan tenant is expected to be filtered out.
        assert "t1" in kept
        assert "t3" in kept
        assert "t2" not in kept

    def test_billing_api_failure_returns_empty(self):
        archiver = WorkflowRunArchiver(days=90)

        with (
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
        ):
            cfg.BILLING_ENABLED = True
            billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down")
            kept = archiver._filter_paid_tenants({"t1"})

        # Fail closed: on a billing API error no tenant is treated as paid.
        assert kept == set()
|
||||
|
||||
|
||||
class TestDryRunArchive:
    """Dry-run mode must never reach the archive storage backend."""

    @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
    def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx):
        archiver = WorkflowRunArchiver(days=90, dry_run=True)

        # With no runs to process, run() should complete without touching storage.
        with patch.object(archiver, "_get_runs_batch", return_value=[]):
            outcome = archiver.run()

        mock_get_storage.assert_not_called()
        assert isinstance(outcome, ArchiveSummary)
        assert outcome.runs_failed == 0
|
||||
@@ -5,6 +5,7 @@ Covers:
|
||||
- License status caching (get_cached_license_status)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -15,9 +16,178 @@ from services.enterprise.enterprise_service import (
|
||||
VALID_LICENSE_CACHE_TTL,
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
WebAppSettings,
|
||||
WorkspacePermission,
|
||||
try_join_default_workspace,
|
||||
)
|
||||
|
||||
MODULE = "services.enterprise.enterprise_service"
|
||||
|
||||
|
||||
class TestEnterpriseServiceInfo:
    """The info endpoints simply proxy EnterpriseRequest.send_request."""

    def test_get_info_delegates(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = {"version": "1.0"}
            info = EnterpriseService.get_info()

            mock_req.send_request.assert_called_once_with("GET", "/info")
            assert info == {"version": "1.0"}

    def test_get_workspace_info_delegates(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = {"name": "ws"}
            workspace_info = EnterpriseService.get_workspace_info("tenant-1")

            mock_req.send_request.assert_called_once_with("GET", "/workspace/tenant-1/info")
            assert workspace_info == {"name": "ws"}
|
||||
|
||||
|
||||
class TestSsoSettingsLastUpdateTime:
    """Parsing and validation of the SSO-settings last-update timestamps."""

    def test_app_sso_parses_valid_timestamp(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = "2025-01-15T10:30:00+00:00"
            parsed = EnterpriseService.get_app_sso_settings_last_update_time()

            assert isinstance(parsed, datetime)
            assert parsed.year == 2025

    def test_app_sso_raises_on_empty(self):
        # An empty response body is treated as missing data.
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = ""
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.get_app_sso_settings_last_update_time()

    def test_app_sso_raises_on_invalid_format(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = "not-a-date"
            with pytest.raises(ValueError, match="Invalid date format"):
                EnterpriseService.get_app_sso_settings_last_update_time()

    def test_workspace_sso_parses_valid_timestamp(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = "2025-06-01T00:00:00+00:00"
            parsed = EnterpriseService.get_workspace_sso_settings_last_update_time()

            assert isinstance(parsed, datetime)

    def test_workspace_sso_raises_on_empty(self):
        # A None response body is treated the same as an empty one.
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = None
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.get_workspace_sso_settings_last_update_time()
|
||||
|
||||
|
||||
class TestWorkspacePermissionService:
    """Input validation and response parsing for WorkspacePermissionService.get_permission."""

    def test_raises_on_empty_workspace_id(self):
        with pytest.raises(ValueError, match="workspace_id must be provided"):
            EnterpriseService.WorkspacePermissionService.get_permission("")

    def test_raises_on_missing_data(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = None
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.WorkspacePermissionService.get_permission("ws-1")

    def test_raises_on_missing_permission_key(self):
        # A response without the "permission" key is treated as missing data.
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = {"other": "data"}
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.WorkspacePermissionService.get_permission("ws-1")

    def test_returns_parsed_permission(self):
        payload = {
            "permission": {
                "workspaceId": "ws-1",
                "allowMemberInvite": True,
                "allowOwnerTransfer": False,
            }
        }
        with patch(f"{MODULE}.EnterpriseRequest") as mock_req:
            mock_req.send_request.return_value = payload
            parsed = EnterpriseService.WorkspacePermissionService.get_permission("ws-1")

            # camelCase API fields map onto snake_case model attributes.
            assert isinstance(parsed, WorkspacePermission)
            assert parsed.workspace_id == "ws-1"
            assert parsed.allow_member_invite is True
            assert parsed.allow_owner_transfer is False
|
||||
|
||||
|
||||
class TestWebAppAuth:
|
||||
def test_is_user_allowed_returns_result_field(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is True
|
||||
|
||||
def test_is_user_allowed_defaults_false(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {}
|
||||
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is False
|
||||
|
||||
def test_batch_is_user_allowed_returns_empty_for_no_apps(self):
|
||||
assert EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", []) == {}
|
||||
|
||||
def test_batch_is_user_allowed_raises_on_empty_response(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = None
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", ["a1"])
|
||||
|
||||
def test_get_app_access_mode_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id("")
|
||||
|
||||
def test_get_app_access_mode_returns_settings(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessMode": "public"}
|
||||
result = EnterpriseService.WebAppAuth.get_app_access_mode_by_id("a1")
|
||||
|
||||
assert isinstance(result, WebAppSettings)
|
||||
assert result.access_mode == "public"
|
||||
|
||||
def test_batch_get_returns_empty_for_no_apps(self):
|
||||
assert EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id([]) == {}
|
||||
|
||||
def test_batch_get_maps_access_modes(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessModes": {"a1": "public", "a2": "private"}}
|
||||
result = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1", "a2"])
|
||||
|
||||
assert result["a1"].access_mode == "public"
|
||||
assert result["a2"].access_mode == "private"
|
||||
|
||||
def test_batch_get_raises_on_invalid_format(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessModes": "not-a-dict"}
|
||||
with pytest.raises(ValueError, match="Invalid data format"):
|
||||
EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1"])
|
||||
|
||||
def test_update_access_mode_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode("", "public")
|
||||
|
||||
def test_update_access_mode_raises_on_invalid_mode(self):
|
||||
with pytest.raises(ValueError, match="access_mode must be"):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode("a1", "invalid")
|
||||
|
||||
def test_update_access_mode_delegates_and_returns(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
result = EnterpriseService.WebAppAuth.update_app_access_mode("a1", "public")
|
||||
|
||||
assert result is True
|
||||
req.send_request.assert_called_once_with(
|
||||
"POST", "/webapp/access-mode", json={"appId": "a1", "accessMode": "public"}
|
||||
)
|
||||
|
||||
def test_cleanup_webapp_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp("")
|
||||
|
||||
def test_cleanup_webapp_delegates(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp("a1")
|
||||
|
||||
req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"})
|
||||
|
||||
|
||||
class TestJoinDefaultWorkspace:
|
||||
def test_join_default_workspace_success(self):
|
||||
|
||||
@@ -7,14 +7,20 @@ This module covers the pre-uninstall plugin hook behavior:
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from httpx import HTTPStatusError
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
CheckCredentialPolicyComplianceRequest,
|
||||
CredentialPolicyViolationError,
|
||||
PluginCredentialType,
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
)
|
||||
|
||||
MODULE = "services.enterprise.plugin_manager_service"
|
||||
|
||||
|
||||
class TestTryPreUninstallPlugin:
|
||||
def test_try_pre_uninstall_plugin_success(self):
|
||||
@@ -88,3 +94,46 @@ class TestTryPreUninstallPlugin:
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
class TestCheckCredentialPolicyCompliance:
|
||||
def _request(self, cred_type=PluginCredentialType.MODEL):
|
||||
return CheckCredentialPolicyComplianceRequest(
|
||||
dify_credential_id="cred-1", provider="openai", credential_type=cred_type
|
||||
)
|
||||
|
||||
def test_passes_when_result_true(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
req.send_request.assert_called_once()
|
||||
|
||||
def test_raises_violation_when_result_false(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = {"result": False}
|
||||
with pytest.raises(CredentialPolicyViolationError, match="Credentials not available"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_raises_violation_on_invalid_response_format(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = "not-a-dict"
|
||||
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_raises_violation_on_api_exception(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.side_effect = ConnectionError("network fail")
|
||||
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_model_dump_serializes_credential_type_as_number(self):
|
||||
body = self._request(PluginCredentialType.TOOL)
|
||||
data = body.model_dump()
|
||||
|
||||
assert data["credential_type"] == 1
|
||||
assert data["dify_credential_id"] == "cred-1"
|
||||
|
||||
def test_model_credential_type_values(self):
|
||||
assert PluginCredentialType.MODEL.to_number() == 0
|
||||
assert PluginCredentialType.TOOL.to_number() == 1
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
MODULE = "services.plugin.plugin_auto_upgrade_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
|
||||
class TestGetStrategy:
|
||||
def test_returns_strategy_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
strategy = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = strategy
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.get_strategy("t1")
|
||||
|
||||
assert result is strategy
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.get_strategy("t1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestChangeStrategy:
|
||||
def test_creates_new_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
"t1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
3,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
"t1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
5,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
["p1"],
|
||||
["p2"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert existing.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert existing.upgrade_time_of_day == 5
|
||||
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
assert existing.include_plugins == ["p2"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestExcludePlugin:
|
||||
def test_creates_default_strategy_when_none_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
p1,
|
||||
p2,
|
||||
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
||||
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
||||
):
|
||||
strat_cls.StrategySetting.FIX_ONLY = "fix_only"
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
cs.return_value = True
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "plugin-1")
|
||||
|
||||
assert result is True
|
||||
cs.assert_called_once()
|
||||
|
||||
def test_appends_to_exclude_list_in_exclude_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p-existing"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p-new")
|
||||
|
||||
assert result is True
|
||||
assert existing.exclude_plugins == ["p-existing", "p-new"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_removes_from_include_list_in_partial_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "partial"
|
||||
existing.include_plugins = ["p1", "p2"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert result is True
|
||||
assert existing.include_plugins == ["p2"]
|
||||
|
||||
def test_switches_to_exclude_mode_from_all(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "all"
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert result is True
|
||||
assert existing.upgrade_mode == "exclude"
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
|
||||
def test_no_duplicate_in_exclude_list(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p1"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
MODULE = "services.plugin.plugin_permission_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
|
||||
class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
|
||||
assert result is permission
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
|
||||
)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
|
||||
)
|
||||
|
||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
session.commit.assert_called_once()
|
||||
session.add.assert_not_called()
|
||||
0
api/tests/unit_tests/services/retention/__init__.py
Normal file
0
api/tests/unit_tests/services/retention/__init__.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from services.retention.conversation.messages_clean_policy import (
|
||||
BillingDisabledPolicy,
|
||||
BillingSandboxPolicy,
|
||||
SimpleMessage,
|
||||
create_message_clean_policy,
|
||||
)
|
||||
|
||||
MODULE = "services.retention.conversation.messages_clean_policy"
|
||||
|
||||
|
||||
def _msg(msg_id: str, app_id: str, days_ago: int = 0) -> SimpleMessage:
|
||||
return SimpleMessage(
|
||||
id=msg_id,
|
||||
app_id=app_id,
|
||||
created_at=datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days_ago),
|
||||
)
|
||||
|
||||
|
||||
class TestBillingDisabledPolicy:
|
||||
def test_returns_all_message_ids(self):
|
||||
policy = BillingDisabledPolicy()
|
||||
msgs = [_msg("m1", "app1"), _msg("m2", "app2"), _msg("m3", "app1")]
|
||||
|
||||
result = policy.filter_message_ids(msgs, {"app1": "t1", "app2": "t2"})
|
||||
|
||||
assert set(result) == {"m1", "m2", "m3"}
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
assert BillingDisabledPolicy().filter_message_ids([], {}) == []
|
||||
|
||||
|
||||
class TestBillingSandboxPolicy:
|
||||
def _policy(self, plans, *, graceful_days=21, whitelist=None, now=1_000_000_000):
|
||||
return BillingSandboxPolicy(
|
||||
plan_provider=lambda _ids: plans,
|
||||
graceful_period_days=graceful_days,
|
||||
tenant_whitelist=whitelist,
|
||||
current_timestamp=now,
|
||||
)
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
assert policy.filter_message_ids([], {"app1": "t1"}) == []
|
||||
|
||||
def test_empty_app_to_tenant_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
assert policy.filter_message_ids([_msg("m1", "app1")], {}) == []
|
||||
|
||||
def test_empty_plans_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
msgs = [_msg("m1", "app1")]
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_non_sandbox_tenant_skipped(self):
|
||||
plans = {"t1": {"plan": "professional", "expiration_date": 0}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_sandbox_no_previous_subscription_deletes(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
|
||||
|
||||
def test_sandbox_expired_beyond_grace_period_deletes(self):
|
||||
now = 1_000_000_000
|
||||
expired_long_ago = now - (22 * 24 * 60 * 60) # 22 days ago > 21 day grace
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_long_ago}}
|
||||
policy = self._policy(plans, now=now)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
|
||||
|
||||
def test_sandbox_within_grace_period_kept(self):
|
||||
now = 1_000_000_000
|
||||
expired_recently = now - (10 * 24 * 60 * 60) # 10 days ago < 21 day grace
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_recently}}
|
||||
policy = self._policy(plans, now=now)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_whitelisted_tenant_skipped(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans, whitelist=["t1"])
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_message_without_tenant_mapping_skipped(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "unmapped_app")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_mixed_tenants_only_sandbox_deleted(self):
|
||||
plans = {
|
||||
"t_sandbox": {"plan": "sandbox", "expiration_date": -1},
|
||||
"t_pro": {"plan": "professional", "expiration_date": 0},
|
||||
}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app_sandbox"), _msg("m2", "app_pro")]
|
||||
app_map = {"app_sandbox": "t_sandbox", "app_pro": "t_pro"}
|
||||
|
||||
result = policy.filter_message_ids(msgs, app_map)
|
||||
|
||||
assert result == ["m1"]
|
||||
|
||||
|
||||
class TestCreateMessageCleanPolicy:
|
||||
def test_billing_disabled_returns_disabled_policy(self):
|
||||
with patch(f"{MODULE}.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = False
|
||||
policy = create_message_clean_policy()
|
||||
|
||||
assert isinstance(policy, BillingDisabledPolicy)
|
||||
|
||||
def test_billing_enabled_returns_sandbox_policy(self):
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config") as cfg,
|
||||
patch(f"{MODULE}.BillingService") as bs,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
bs.get_expired_subscription_cleanup_whitelist.return_value = ["wl1"]
|
||||
bs.get_plan_bulk_with_cache = MagicMock()
|
||||
policy = create_message_clean_policy(graceful_period_days=30)
|
||||
|
||||
assert isinstance(policy, BillingSandboxPolicy)
|
||||
@@ -0,0 +1,598 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderType
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
MODULE = "services.tools.tools_transform_service"
|
||||
|
||||
|
||||
class TestToolTransformService:
|
||||
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
|
||||
|
||||
def test_convert_tool_with_parameter_override(self):
|
||||
"""Test that runtime parameters correctly override base parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
base_param2 = Mock(spec=ToolParameter)
|
||||
base_param2.name = "param2"
|
||||
base_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param2.type = "string"
|
||||
base_param2.label = "Base Param 2"
|
||||
|
||||
# Create mock runtime parameters that override base parameters
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1" # Different label to verify override
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1, base_param2]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_tool"
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Find the overridden parameter
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
|
||||
|
||||
# Find the non-overridden parameter
|
||||
original_param = next((p for p in result.parameters if p.name == "param2"), None)
|
||||
assert original_param is not None
|
||||
assert original_param.label == "Base Param 2" # Should be base version
|
||||
|
||||
def test_convert_tool_with_additional_runtime_parameters(self):
|
||||
"""Test that additional runtime parameters are added to the final list"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters - one that overrides and one that's new
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "runtime_only"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "Runtime Only Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Check that both parameters are present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "runtime_only" in param_names
|
||||
|
||||
# Verify the overridden parameter has runtime version
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1"
|
||||
|
||||
# Verify the new runtime parameter is included
|
||||
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
|
||||
assert new_param is not None
|
||||
assert new_param.label == "Runtime Only Param"
|
||||
|
||||
def test_convert_tool_with_non_form_runtime_parameters(self):
|
||||
"""Test that non-FORM runtime parameters are not added as new parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters with different forms
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "llm_param"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "LLM Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 1 # Only the FORM parameter should be present
|
||||
|
||||
# Check that only the FORM parameter is present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "llm_param" not in param_names
|
||||
|
||||
def test_convert_tool_with_empty_parameters(self):
|
||||
"""Test conversion with empty base and runtime parameters"""
|
||||
# Create mock tool with no parameters
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = []
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = []
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 0
|
||||
|
||||
def test_convert_tool_with_none_parameters(self):
    """Conversion tolerates a `None` base-parameter list and still
    produces an empty (not None) parameter list."""
    none_param_tool = Mock(spec=Tool)

    identity = Mock()
    identity.author = "test_author"
    identity.name = "test_tool"
    identity.label = I18nObject(en_US="Test Tool")

    description = Mock()
    description.human = I18nObject(en_US="Test description")

    entity = Mock()
    entity.parameters = None  # base parameters deliberately absent
    entity.identity = identity
    entity.description = description
    entity.output_schema = {}
    none_param_tool.entity = entity

    none_param_tool.get_runtime_parameters.return_value = []
    # The service forks the tool runtime before reading parameters; keep
    # the call chain on this same mock.
    none_param_tool.fork_tool_runtime.return_value = none_param_tool

    result = ToolTransformService.convert_tool_entity_to_api_entity(none_param_tool, "test_tenant", None)

    assert isinstance(result, ToolApiEntity)
    assert result.parameters is not None
    assert len(result.parameters) == 0
def test_convert_tool_parameter_order_preserved(self):
    """Base parameter order must be kept; runtime overrides replace their
    base counterpart in place, and brand-new runtime parameters append at
    the end."""

    def form_param(name, label):
        # Small factory: every parameter in this test is a string-typed
        # FORM parameter, differing only in name and label.
        param = Mock(spec=ToolParameter)
        param.name = name
        param.form = ToolParameter.ToolParameterForm.FORM
        param.type = "string"
        param.label = label
        return param

    base_params = [
        form_param("param1", "Base Param 1"),
        form_param("param2", "Base Param 2"),
        form_param("param3", "Base Param 3"),
    ]
    runtime_params = [
        form_param("param2", "Runtime Param 2"),  # overrides the middle base param
        form_param("param4", "Runtime Param 4"),  # new runtime-only param
    ]

    ordered_tool = Mock(spec=Tool)
    entity = Mock()
    entity.parameters = base_params
    entity.identity = Mock()
    entity.identity.author = "test_author"
    entity.identity.name = "test_tool"
    entity.identity.label = I18nObject(en_US="Test Tool")
    entity.description = Mock()
    entity.description.human = I18nObject(en_US="Test description")
    entity.output_schema = {}
    ordered_tool.entity = entity
    ordered_tool.get_runtime_parameters.return_value = runtime_params
    # The service forks the tool runtime before reading parameters; keep
    # the call chain on this same mock.
    ordered_tool.fork_tool_runtime.return_value = ordered_tool

    result = ToolTransformService.convert_tool_entity_to_api_entity(ordered_tool, "test_tenant", None)

    assert isinstance(result, ToolApiEntity)
    assert result.parameters is not None
    assert len(result.parameters) == 4

    # Base order first, new runtime parameter last.
    assert [p.name for p in result.parameters] == ["param1", "param2", "param3", "param4"]

    # The overridden slot carries the runtime version, not the base one.
    overridden = result.parameters[1]
    assert overridden.name == "param2"
    assert overridden.label == "Runtime Param 2"
class TestWorkflowProviderToUserProvider:
    """Test cases for ToolTransformService.workflow_provider_to_user_provider method"""

    @staticmethod
    def _build_controller(
        provider_id,
        author="test_author",
        name="test_workflow_tool",
        description=None,
        icon=None,
        icon_dark=None,
        label=None,
    ):
        """Build a mocked WorkflowToolProviderController with overridable identity fields."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        controller = Mock(spec=WorkflowToolProviderController)
        controller.provider_id = provider_id
        controller.entity = Mock()
        identity = controller.entity.identity = Mock()
        identity.author = author
        identity.name = name
        identity.description = description if description is not None else I18nObject(en_US="Test description")
        identity.icon = icon if icon is not None else {"type": "emoji", "content": "🔧"}
        identity.icon_dark = icon_dark
        identity.label = label if label is not None else I18nObject(en_US="Test Workflow Tool")
        return controller

    def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
        workflow_app_id = "app_123"
        controller = self._build_controller("provider_123")

        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=controller,
            labels=["label1", "label2"],
            workflow_app_id=workflow_app_id,
        )

        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_123"
        assert result.author == "test_author"
        assert result.name == "test_workflow_tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["label1", "label2"]
        assert result.is_team_authorization is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []

    def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
        controller = self._build_controller("provider_123")

        # Call without workflow_app_id at all; it should default to None.
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=controller,
            labels=["label1"],
        )

        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_123"
        assert result.workflow_app_id is None
        assert result.labels == ["label1"]

    def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
        """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
        controller = self._build_controller("provider_123")

        # Explicit None for both optional arguments.
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=controller,
            labels=None,
            workflow_app_id=None,
        )

        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_123"
        assert result.workflow_app_id is None
        assert result.labels == []

    def test_workflow_provider_to_user_provider_preserves_other_fields(self):
        """Test that workflow_provider_to_user_provider preserves all other entity fields."""
        controller = self._build_controller(
            "provider_456",
            author="another_author",
            name="another_workflow_tool",
            description=I18nObject(en_US="Another description", zh_Hans="Another description"),
            icon={"type": "emoji", "content": "⚙️"},
            icon_dark={"type": "emoji", "content": "🔧"},
            label=I18nObject(en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"),
        )

        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=controller,
            labels=["automation", "workflow"],
            workflow_app_id="app_456",
        )

        # Every identity field must survive the transformation unchanged.
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_456"
        assert result.author == "another_author"
        assert result.name == "another_workflow_tool"
        assert result.description.en_US == "Another description"
        assert result.description.zh_Hans == "Another description"
        assert result.icon == {"type": "emoji", "content": "⚙️"}
        assert result.icon_dark == {"type": "emoji", "content": "🔧"}
        assert result.label.en_US == "Another Workflow Tool"
        assert result.label.zh_Hans == "Another Workflow Tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == "app_456"
        assert result.labels == ["automation", "workflow"]
        assert result.masked_credentials == {}
        assert result.is_team_authorization is True
        assert result.allow_delete is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []
class TestGetToolProviderIconUrl:
    """Tests for ToolTransformService.get_tool_provider_icon_url."""

    def test_builtin_provider_returns_console_url(self):
        # Builtin providers resolve to an icon endpoint under the console API URL.
        with patch(f"{MODULE}.dify_config") as cfg:
            cfg.CONSOLE_API_URL = "https://app.dify.ai"
            url = ToolTransformService.get_tool_provider_icon_url("builtin", "google", "icon.png")

        assert "/builtin/google/icon" in url
        assert url.startswith("https://app.dify.ai/console/api/workspaces/current/tool-provider")

    def test_builtin_provider_with_no_console_url(self):
        # The icon path is still produced when no console URL is configured.
        with patch(f"{MODULE}.dify_config") as cfg:
            cfg.CONSOLE_API_URL = None
            url = ToolTransformService.get_tool_provider_icon_url("builtin", "slack", "icon.png")

        assert "/builtin/slack/icon" in url

    def test_api_provider_parses_json_icon(self):
        # A JSON-encoded icon string is decoded into a dict.
        parsed = ToolTransformService.get_tool_provider_icon_url(
            "api", "my-api", '{"background": "#fff", "content": "A"}'
        )
        assert parsed == {"background": "#fff", "content": "A"}

    def test_api_provider_returns_dict_icon_directly(self):
        dict_icon = {"background": "#000", "content": "B"}
        assert ToolTransformService.get_tool_provider_icon_url("api", "my-api", dict_icon) == dict_icon

    def test_api_provider_returns_fallback_on_invalid_json(self):
        # Unparseable icon strings fall back to the default emoji icon.
        fallback = ToolTransformService.get_tool_provider_icon_url("api", "my-api", "not-json")
        assert fallback == {"background": "#252525", "content": "\ud83d\ude01"}

    def test_workflow_provider_behaves_like_api(self):
        wf_icon = {"background": "#123", "content": "W"}
        assert ToolTransformService.get_tool_provider_icon_url("workflow", "wf", wf_icon) == wf_icon

    def test_mcp_returns_icon_as_is(self):
        assert ToolTransformService.get_tool_provider_icon_url("mcp", "srv", "icon-value") == "icon-value"

    def test_unknown_type_returns_empty(self):
        assert ToolTransformService.get_tool_provider_icon_url("unknown", "x", "i") == ""
class TestRepackProvider:
    """Tests for ToolTransformService.repack_provider."""

    def test_repacks_dict_provider_icon(self):
        # Dict-shaped providers get their icon rewritten through the resolver.
        provider = {"type": "builtin", "name": "google", "icon": "old"}
        with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/new-url") as mock_fn:
            ToolTransformService.repack_provider("t1", provider)

        assert provider["icon"] == "/new-url"
        mock_fn.assert_called_once_with(provider_type="builtin", provider_name="google", icon="old")

    def test_repacks_tool_provider_api_entity_without_plugin(self):
        entity = MagicMock(spec=ToolProviderApiEntity)
        entity.configure_mock(
            plugin_id=None,
            type=ToolProviderType.BUILT_IN,
            name="slack",
            icon="icon.svg",
            icon_dark="dark.svg",
        )

        with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/url"):
            ToolTransformService.repack_provider("t1", entity)

        # Both light and dark icons are rewritten through the same resolver.
        assert entity.icon == "/url"
        assert entity.icon_dark == "/url"
class TestConvertMcpSchemaToParameter:
    """Tests for ToolTransformService.convert_mcp_schema_to_parameter."""

    def test_simple_object_schema(self):
        schema = {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"},
                "count": {"type": "integer", "description": "Result count"},
            },
            "required": ["query"],
        }

        params = ToolTransformService.convert_mcp_schema_to_parameter(schema)

        assert len(params) == 2
        by_name = {p.name: p for p in params}
        assert by_name["query"].required is True
        assert by_name["count"].required is False
        # JSON-schema "integer" collapses into the generic number type.
        assert by_name["count"].type.value == "number"

    def test_float_maps_to_number(self):
        schema = {"type": "object", "properties": {"rate": {"type": "float"}}, "required": []}
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert param.type.value == "number"

    def test_array_type_attaches_input_schema(self):
        array_prop = {"type": "array", "description": "Items", "items": {"type": "string"}}
        schema = {"type": "object", "properties": {"items": array_prop}, "required": []}
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        # Array parameters carry their item schema through.
        assert param.input_schema is not None

    def test_non_object_schema_returns_empty(self):
        assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "string"}) == []

    def test_missing_properties_returns_empty(self):
        assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "object"}) == []

    def test_list_type_uses_first_element(self):
        # Nullable unions like ["string", "null"] resolve to the first entry.
        schema = {"type": "object", "properties": {"f": {"type": ["string", "null"]}}, "required": []}
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert param.type.value == "string"

    def test_missing_description_defaults_empty(self):
        schema = {"type": "object", "properties": {"f": {"type": "string"}}, "required": []}
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert param.llm_description == ""
class TestApiProviderToController:
    """Tests for ToolTransformService.api_provider_to_controller's auth-type mapping."""

    @staticmethod
    def _assert_auth_mapping(credential_auth_type, expected_auth_type):
        """Run the conversion with the given credential auth_type and check the mapped enum."""
        db_provider = MagicMock()
        db_provider.credentials = {"auth_type": credential_auth_type}
        with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
            ctrl_cls.from_db.return_value = MagicMock()
            ToolTransformService.api_provider_to_controller(db_provider)
        ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=expected_auth_type)

    def test_api_key_header_auth(self):
        self._assert_auth_mapping("api_key_header", ApiProviderAuthType.API_KEY_HEADER)

    def test_api_key_query_auth(self):
        self._assert_auth_mapping("api_key_query", ApiProviderAuthType.API_KEY_QUERY)

    def test_legacy_api_key_maps_to_header(self):
        # Historical "api_key" credentials are treated as header-based auth.
        self._assert_auth_mapping("api_key", ApiProviderAuthType.API_KEY_HEADER)

    def test_unknown_auth_defaults_to_none(self):
        self._assert_auth_mapping("something_else", ApiProviderAuthType.NONE)
Reference in New Issue
Block a user