test: add tests for api/services retention, enterprise, plugin (#32648)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
Dev Sharma
2026-03-31 08:46:42 +05:30
committed by GitHub
parent 7e4754392d
commit 2de818530b
12 changed files with 1917 additions and 0 deletions

View File

@@ -0,0 +1,182 @@
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import Tenant
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_permission_service import PluginPermissionService
@pytest.fixture
def tenant(flask_req_ctx):
with session_factory.create_session() as session:
t = Tenant(name="plugin_it_tenant")
session.add(t)
session.commit()
tenant_id = t.id
yield tenant_id
with session_factory.create_session() as session:
session.execute(delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id))
session.execute(
delete(TenantPluginAutoUpgradeStrategy).where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
)
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
session.commit()
class TestPluginPermissionLifecycle:
def test_get_returns_none_for_new_tenant(self, tenant):
assert PluginPermissionService.get_permission(tenant) is None
def test_change_creates_row(self, tenant):
result = PluginPermissionService.change_permission(
tenant,
TenantPluginPermission.InstallPermission.ADMINS,
TenantPluginPermission.DebugPermission.EVERYONE,
)
assert result is True
perm = PluginPermissionService.get_permission(tenant)
assert perm is not None
assert perm.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert perm.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
def test_change_updates_existing_row(self, tenant):
PluginPermissionService.change_permission(
tenant,
TenantPluginPermission.InstallPermission.ADMINS,
TenantPluginPermission.DebugPermission.NOBODY,
)
PluginPermissionService.change_permission(
tenant,
TenantPluginPermission.InstallPermission.EVERYONE,
TenantPluginPermission.DebugPermission.ADMINS,
)
perm = PluginPermissionService.get_permission(tenant)
assert perm is not None
assert perm.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
with session_factory.create_session() as session:
count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count()
assert count == 1
class TestPluginAutoUpgradeLifecycle:
def test_get_returns_none_for_new_tenant(self, tenant):
assert PluginAutoUpgradeService.get_strategy(tenant) is None
def test_change_creates_row(self, tenant):
result = PluginAutoUpgradeService.change_strategy(
tenant,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
upgrade_time_of_day=3,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
exclude_plugins=[],
include_plugins=[],
)
assert result is True
strategy = PluginAutoUpgradeService.get_strategy(tenant)
assert strategy is not None
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
assert strategy.upgrade_time_of_day == 3
def test_change_updates_existing_row(self, tenant):
PluginAutoUpgradeService.change_strategy(
tenant,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
upgrade_time_of_day=0,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
exclude_plugins=[],
include_plugins=[],
)
PluginAutoUpgradeService.change_strategy(
tenant,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
upgrade_time_of_day=12,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
exclude_plugins=[],
include_plugins=["plugin-a"],
)
strategy = PluginAutoUpgradeService.get_strategy(tenant)
assert strategy is not None
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
assert strategy.upgrade_time_of_day == 12
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
assert strategy.include_plugins == ["plugin-a"]
def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant):
PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin")
strategy = PluginAutoUpgradeService.get_strategy(tenant)
assert strategy is not None
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
assert "my-plugin" in strategy.exclude_plugins
def test_exclude_plugin_appends_in_exclude_mode(self, tenant):
PluginAutoUpgradeService.change_strategy(
tenant,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
upgrade_time_of_day=0,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
exclude_plugins=["existing"],
include_plugins=[],
)
PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin")
strategy = PluginAutoUpgradeService.get_strategy(tenant)
assert strategy is not None
assert "existing" in strategy.exclude_plugins
assert "new-plugin" in strategy.exclude_plugins
def test_exclude_plugin_dedup_in_exclude_mode(self, tenant):
PluginAutoUpgradeService.change_strategy(
tenant,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
upgrade_time_of_day=0,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
exclude_plugins=["same-plugin"],
include_plugins=[],
)
PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin")
strategy = PluginAutoUpgradeService.get_strategy(tenant)
assert strategy is not None
assert strategy.exclude_plugins.count("same-plugin") == 1
def test_exclude_from_partial_mode_removes_from_include(self, tenant):
PluginAutoUpgradeService.change_strategy(
tenant,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
upgrade_time_of_day=0,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
exclude_plugins=[],
include_plugins=["p1", "p2"],
)
PluginAutoUpgradeService.exclude_plugin(tenant, "p1")
strategy = PluginAutoUpgradeService.get_strategy(tenant)
assert strategy is not None
assert "p1" not in strategy.include_plugins
assert "p2" in strategy.include_plugins
def test_exclude_from_all_mode_switches_to_exclude(self, tenant):
PluginAutoUpgradeService.change_strategy(
tenant,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
upgrade_time_of_day=0,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
exclude_plugins=[],
include_plugins=[],
)
PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin")
strategy = PluginAutoUpgradeService.get_strategy(tenant)
assert strategy is not None
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
assert "excluded-plugin" in strategy.exclude_plugins

View File

@@ -0,0 +1,348 @@
import datetime
import math
import uuid
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import Tenant
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import (
App,
Conversation,
Message,
MessageAnnotation,
MessageFeedback,
)
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
from services.retention.conversation.messages_clean_service import MessagesCleanService
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
_OLD = _NOW - datetime.timedelta(days=60)
_VERY_OLD = _NOW - datetime.timedelta(days=90)
_RECENT = _NOW - datetime.timedelta(days=5)
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)
_DEFAULT_BATCH_SIZE = 100
_PAGINATION_MESSAGE_COUNT = 25
_PAGINATION_BATCH_SIZE = 8
@pytest.fixture
def tenant_and_app(flask_req_ctx):
"""Creates a Tenant, App and Conversation for the test and cleans up after."""
with session_factory.create_session() as session:
tenant = Tenant(name="retention_it_tenant")
session.add(tenant)
session.flush()
app = App(
tenant_id=tenant.id,
name="Retention IT App",
mode="chat",
enable_site=True,
enable_api=True,
)
session.add(app)
session.flush()
conv = Conversation(
app_id=app.id,
mode="chat",
name="test_conv",
status="normal",
from_source="console",
_inputs={},
)
session.add(conv)
session.commit()
tenant_id = tenant.id
app_id = app.id
conv_id = conv.id
yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id}
with session_factory.create_session() as session:
session.execute(delete(Conversation).where(Conversation.id == conv_id))
session.execute(delete(App).where(App.id == app_id))
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
session.commit()
def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message:
return Message(
app_id=app_id,
conversation_id=conversation_id,
query="test",
message=[{"text": "hello"}],
answer="world",
message_tokens=1,
message_unit_price=0,
answer_tokens=1,
answer_unit_price=0,
from_source="console",
currency="USD",
_inputs={},
created_at=created_at,
)
class TestMessagesCleanServiceIntegration:
@pytest.fixture
def seed_messages(self, tenant_and_app):
"""Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.
Yields a semantic mapping keyed by age label.
"""
data = tenant_and_app
app_id = data["app_id"]
conv_id = data["conversation_id"]
# Ordered tuple of (label, timestamp) for deterministic seeding
timestamps = [
("very_old", _VERY_OLD),
("old", _OLD),
("recent", _RECENT),
]
msg_ids: dict[str, str] = {}
with session_factory.create_session() as session:
for label, ts in timestamps:
msg = _make_message(app_id, conv_id, ts)
session.add(msg)
session.flush()
msg_ids[label] = msg.id
session.commit()
yield {"msg_ids": msg_ids, **data}
with session_factory.create_session() as session:
session.execute(
delete(Message)
.where(Message.id.in_(list(msg_ids.values())))
.execution_options(synchronize_session=False)
)
session.commit()
@pytest.fixture
def paginated_seed_messages(self, tenant_and_app):
"""Seeds multiple messages separated by 1-second increments starting at _OLD."""
data = tenant_and_app
app_id = data["app_id"]
conv_id = data["conversation_id"]
msg_ids: list[str] = []
with session_factory.create_session() as session:
for i in range(_PAGINATION_MESSAGE_COUNT):
ts = _OLD + datetime.timedelta(seconds=i)
msg = _make_message(app_id, conv_id, ts)
session.add(msg)
session.flush()
msg_ids.append(msg.id)
session.commit()
yield {"msg_ids": msg_ids, **data}
with session_factory.create_session() as session:
session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False))
session.commit()
@pytest.fixture
def cascade_test_data(self, tenant_and_app):
"""Seeds one Message with an associated Feedback and Annotation."""
data = tenant_and_app
app_id = data["app_id"]
conv_id = data["conversation_id"]
with session_factory.create_session() as session:
msg = _make_message(app_id, conv_id, _OLD)
session.add(msg)
session.flush()
feedback = MessageFeedback(
app_id=app_id,
conversation_id=conv_id,
message_id=msg.id,
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
)
annotation = MessageAnnotation(
app_id=app_id,
conversation_id=conv_id,
message_id=msg.id,
question="q",
content="a",
account_id=str(uuid.uuid4()),
)
session.add_all([feedback, annotation])
session.commit()
msg_id = msg.id
fb_id = feedback.id
ann_id = annotation.id
yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}
with session_factory.create_session() as session:
session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
session.execute(delete(Message).where(Message.id == msg_id))
session.commit()
def test_dry_run_does_not_delete(self, seed_messages):
"""Dry-run must count eligible rows without deleting any of them."""
data = seed_messages
msg_ids = data["msg_ids"]
all_ids = list(msg_ids.values())
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_WINDOW_START,
end_before=_WINDOW_END,
batch_size=_DEFAULT_BATCH_SIZE,
dry_run=True,
)
stats = svc.run()
assert stats["filtered_messages"] == len(all_ids)
assert stats["total_deleted"] == 0
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
assert remaining == len(all_ids)
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
"""All 3 seeded messages fall within the window and must be deleted."""
data = seed_messages
msg_ids = data["msg_ids"]
all_ids = list(msg_ids.values())
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_WINDOW_START,
end_before=_WINDOW_END,
batch_size=_DEFAULT_BATCH_SIZE,
dry_run=False,
)
stats = svc.run()
assert stats["total_deleted"] == len(all_ids)
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
assert remaining == 0
def test_start_from_filters_correctly(self, seed_messages):
"""Only the message at _OLD falls within the narrow ±1 h window."""
data = seed_messages
msg_ids = data["msg_ids"]
start = _OLD - datetime.timedelta(hours=1)
end = _OLD + datetime.timedelta(hours=1)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=start,
end_before=end,
batch_size=_DEFAULT_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
all_ids = list(msg_ids.values())
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
assert msg_ids["old"] not in remaining_ids
assert msg_ids["very_old"] in remaining_ids
assert msg_ids["recent"] in remaining_ids
def test_cursor_pagination_across_batches(self, paginated_seed_messages):
"""Messages must be deleted across multiple batches."""
data = paginated_seed_messages
msg_ids = data["msg_ids"]
# _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s.
pagination_window_start = _OLD - datetime.timedelta(seconds=1)
pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=pagination_window_start,
end_before=pagination_window_end,
batch_size=_PAGINATION_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
assert stats["batches"] >= expected_batches
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
assert remaining == 0
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
"""A window entirely in the future must yield zero matches."""
far_future = _NOW + datetime.timedelta(days=365)
even_further = far_future + datetime.timedelta(days=1)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=far_future,
end_before=even_further,
batch_size=_DEFAULT_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_messages"] == 0
assert stats["total_deleted"] == 0
def test_relation_cascade_deletes(self, cascade_test_data):
"""Deleting a Message must cascade to its Feedback and Annotation rows."""
data = cascade_test_data
msg_id = data["msg_id"]
fb_id = data["fb_id"]
ann_id = data["ann_id"]
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_OLD - datetime.timedelta(hours=1),
end_before=_OLD + datetime.timedelta(hours=1),
batch_size=_DEFAULT_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
assert session.query(Message).where(Message.id == msg_id).count() == 0
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
def test_factory_from_time_range_validation(self):
with pytest.raises(ValueError, match="start_from"):
MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_NOW,
end_before=_OLD,
)
def test_factory_from_days_validation(self):
with pytest.raises(ValueError, match="days"):
MessagesCleanService.from_days(
policy=BillingDisabledPolicy(),
days=-1,
)
def test_factory_batch_size_validation(self):
with pytest.raises(ValueError, match="batch_size"):
MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_OLD,
end_before=_NOW,
batch_size=0,
)

View File

@@ -0,0 +1,177 @@
import datetime
import io
import json
import uuid
import zipfile
from unittest.mock import MagicMock, patch
import pytest
from services.retention.workflow_run.archive_paid_plan_workflow_run import (
ArchiveSummary,
WorkflowRunArchiver,
)
from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION
class TestWorkflowRunArchiverInit:
def test_start_from_without_end_before_raises(self):
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
WorkflowRunArchiver(start_from=datetime.datetime(2025, 1, 1))
def test_end_before_without_start_from_raises(self):
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
WorkflowRunArchiver(end_before=datetime.datetime(2025, 1, 1))
def test_start_equals_end_raises(self):
ts = datetime.datetime(2025, 1, 1)
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
WorkflowRunArchiver(start_from=ts, end_before=ts)
def test_start_after_end_raises(self):
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
WorkflowRunArchiver(
start_from=datetime.datetime(2025, 6, 1),
end_before=datetime.datetime(2025, 1, 1),
)
def test_workers_zero_raises(self):
with pytest.raises(ValueError, match="workers must be at least 1"):
WorkflowRunArchiver(workers=0)
def test_valid_init_defaults(self):
archiver = WorkflowRunArchiver(days=30, batch_size=50)
assert archiver.days == 30
assert archiver.batch_size == 50
assert archiver.dry_run is False
assert archiver.delete_after_archive is False
assert archiver.start_from is None
def test_valid_init_with_time_range(self):
start = datetime.datetime(2025, 1, 1)
end = datetime.datetime(2025, 6, 1)
archiver = WorkflowRunArchiver(start_from=start, end_before=end, workers=2)
assert archiver.start_from is not None
assert archiver.end_before is not None
assert archiver.workers == 2
class TestBuildArchiveBundle:
def test_bundle_contains_manifest_and_all_tables(self):
archiver = WorkflowRunArchiver(days=90)
manifest_data = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8")
table_payloads = dict.fromkeys(archiver.ARCHIVED_TABLES, b"")
bundle_bytes = archiver._build_archive_bundle(manifest_data, table_payloads)
with zipfile.ZipFile(io.BytesIO(bundle_bytes), "r") as zf:
names = set(zf.namelist())
assert "manifest.json" in names
for table in archiver.ARCHIVED_TABLES:
assert f"{table}.jsonl" in names, f"Missing {table}.jsonl in bundle"
def test_bundle_missing_table_payload_raises(self):
archiver = WorkflowRunArchiver(days=90)
manifest_data = b"{}"
incomplete_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"}
with pytest.raises(ValueError, match="Missing archive payload"):
archiver._build_archive_bundle(manifest_data, incomplete_payloads)
class TestGenerateManifest:
def test_manifest_structure(self):
archiver = WorkflowRunArchiver(days=90)
from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats
run = MagicMock()
run.id = str(uuid.uuid4())
run.tenant_id = str(uuid.uuid4())
run.app_id = str(uuid.uuid4())
run.workflow_id = str(uuid.uuid4())
run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0)
stats = [
TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512),
TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024),
]
manifest = archiver._generate_manifest(run, stats)
assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION
assert manifest["workflow_run_id"] == run.id
assert manifest["tenant_id"] == run.tenant_id
assert manifest["app_id"] == run.app_id
assert "tables" in manifest
assert manifest["tables"]["workflow_runs"]["row_count"] == 1
assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123"
assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2
class TestFilterPaidTenants:
def test_all_tenants_paid_when_billing_disabled(self):
archiver = WorkflowRunArchiver(days=90)
tenant_ids = {"t1", "t2", "t3"}
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
cfg.BILLING_ENABLED = False
result = archiver._filter_paid_tenants(tenant_ids)
assert result == tenant_ids
def test_empty_tenants_returns_empty(self):
archiver = WorkflowRunArchiver(days=90)
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
cfg.BILLING_ENABLED = True
result = archiver._filter_paid_tenants(set())
assert result == set()
def test_only_paid_plans_returned(self):
archiver = WorkflowRunArchiver(days=90)
mock_bulk = {
"t1": {"plan": "professional"},
"t2": {"plan": "sandbox"},
"t3": {"plan": "team"},
}
with (
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
):
cfg.BILLING_ENABLED = True
billing.get_plan_bulk_with_cache.return_value = mock_bulk
result = archiver._filter_paid_tenants({"t1", "t2", "t3"})
assert "t1" in result
assert "t3" in result
assert "t2" not in result
def test_billing_api_failure_returns_empty(self):
archiver = WorkflowRunArchiver(days=90)
with (
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
):
cfg.BILLING_ENABLED = True
billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down")
result = archiver._filter_paid_tenants({"t1"})
assert result == set()
class TestDryRunArchive:
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx):
archiver = WorkflowRunArchiver(days=90, dry_run=True)
with patch.object(archiver, "_get_runs_batch", return_value=[]):
summary = archiver.run()
mock_get_storage.assert_not_called()
assert isinstance(summary, ArchiveSummary)
assert summary.runs_failed == 0

View File

@@ -5,6 +5,7 @@ Covers:
- License status caching (get_cached_license_status)
"""
from datetime import datetime
from unittest.mock import patch
import pytest
@@ -15,9 +16,178 @@ from services.enterprise.enterprise_service import (
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult,
EnterpriseService,
WebAppSettings,
WorkspacePermission,
try_join_default_workspace,
)
MODULE = "services.enterprise.enterprise_service"
class TestEnterpriseServiceInfo:
def test_get_info_delegates(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"version": "1.0"}
result = EnterpriseService.get_info()
req.send_request.assert_called_once_with("GET", "/info")
assert result == {"version": "1.0"}
def test_get_workspace_info_delegates(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"name": "ws"}
result = EnterpriseService.get_workspace_info("tenant-1")
req.send_request.assert_called_once_with("GET", "/workspace/tenant-1/info")
assert result == {"name": "ws"}
class TestSsoSettingsLastUpdateTime:
def test_app_sso_parses_valid_timestamp(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = "2025-01-15T10:30:00+00:00"
result = EnterpriseService.get_app_sso_settings_last_update_time()
assert isinstance(result, datetime)
assert result.year == 2025
def test_app_sso_raises_on_empty(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = ""
with pytest.raises(ValueError, match="No data found"):
EnterpriseService.get_app_sso_settings_last_update_time()
def test_app_sso_raises_on_invalid_format(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = "not-a-date"
with pytest.raises(ValueError, match="Invalid date format"):
EnterpriseService.get_app_sso_settings_last_update_time()
def test_workspace_sso_parses_valid_timestamp(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = "2025-06-01T00:00:00+00:00"
result = EnterpriseService.get_workspace_sso_settings_last_update_time()
assert isinstance(result, datetime)
def test_workspace_sso_raises_on_empty(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = None
with pytest.raises(ValueError, match="No data found"):
EnterpriseService.get_workspace_sso_settings_last_update_time()
class TestWorkspacePermissionService:
def test_raises_on_empty_workspace_id(self):
with pytest.raises(ValueError, match="workspace_id must be provided"):
EnterpriseService.WorkspacePermissionService.get_permission("")
def test_raises_on_missing_data(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = None
with pytest.raises(ValueError, match="No data found"):
EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
def test_raises_on_missing_permission_key(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"other": "data"}
with pytest.raises(ValueError, match="No data found"):
EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
def test_returns_parsed_permission(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {
"permission": {
"workspaceId": "ws-1",
"allowMemberInvite": True,
"allowOwnerTransfer": False,
}
}
result = EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
assert isinstance(result, WorkspacePermission)
assert result.workspace_id == "ws-1"
assert result.allow_member_invite is True
assert result.allow_owner_transfer is False
class TestWebAppAuth:
def test_is_user_allowed_returns_result_field(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"result": True}
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is True
def test_is_user_allowed_defaults_false(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {}
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is False
def test_batch_is_user_allowed_returns_empty_for_no_apps(self):
assert EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", []) == {}
def test_batch_is_user_allowed_raises_on_empty_response(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = None
with pytest.raises(ValueError, match="No data found"):
EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", ["a1"])
def test_get_app_access_mode_raises_on_empty_app_id(self):
with pytest.raises(ValueError, match="app_id must be provided"):
EnterpriseService.WebAppAuth.get_app_access_mode_by_id("")
def test_get_app_access_mode_returns_settings(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"accessMode": "public"}
result = EnterpriseService.WebAppAuth.get_app_access_mode_by_id("a1")
assert isinstance(result, WebAppSettings)
assert result.access_mode == "public"
def test_batch_get_returns_empty_for_no_apps(self):
assert EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id([]) == {}
def test_batch_get_maps_access_modes(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"accessModes": {"a1": "public", "a2": "private"}}
result = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1", "a2"])
assert result["a1"].access_mode == "public"
assert result["a2"].access_mode == "private"
def test_batch_get_raises_on_invalid_format(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"accessModes": "not-a-dict"}
with pytest.raises(ValueError, match="Invalid data format"):
EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1"])
def test_update_access_mode_raises_on_empty_app_id(self):
with pytest.raises(ValueError, match="app_id must be provided"):
EnterpriseService.WebAppAuth.update_app_access_mode("", "public")
def test_update_access_mode_raises_on_invalid_mode(self):
with pytest.raises(ValueError, match="access_mode must be"):
EnterpriseService.WebAppAuth.update_app_access_mode("a1", "invalid")
def test_update_access_mode_delegates_and_returns(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
req.send_request.return_value = {"result": True}
result = EnterpriseService.WebAppAuth.update_app_access_mode("a1", "public")
assert result is True
req.send_request.assert_called_once_with(
"POST", "/webapp/access-mode", json={"appId": "a1", "accessMode": "public"}
)
def test_cleanup_webapp_raises_on_empty_app_id(self):
with pytest.raises(ValueError, match="app_id must be provided"):
EnterpriseService.WebAppAuth.cleanup_webapp("")
def test_cleanup_webapp_delegates(self):
with patch(f"{MODULE}.EnterpriseRequest") as req:
EnterpriseService.WebAppAuth.cleanup_webapp("a1")
req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"})
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):

View File

@@ -7,14 +7,20 @@ This module covers the pre-uninstall plugin hook behavior:
from unittest.mock import patch
import pytest
from httpx import HTTPStatusError
from configs import dify_config
from services.enterprise.plugin_manager_service import (
CheckCredentialPolicyComplianceRequest,
CredentialPolicyViolationError,
PluginCredentialType,
PluginManagerService,
PreUninstallPluginRequest,
)
MODULE = "services.enterprise.plugin_manager_service"
class TestTryPreUninstallPlugin:
def test_try_pre_uninstall_plugin_success(self):
@@ -88,3 +94,46 @@ class TestTryPreUninstallPlugin:
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
class TestCheckCredentialPolicyCompliance:
def _request(self, cred_type=PluginCredentialType.MODEL):
return CheckCredentialPolicyComplianceRequest(
dify_credential_id="cred-1", provider="openai", credential_type=cred_type
)
def test_passes_when_result_true(self):
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
req.send_request.return_value = {"result": True}
PluginManagerService.check_credential_policy_compliance(self._request())
req.send_request.assert_called_once()
def test_raises_violation_when_result_false(self):
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
req.send_request.return_value = {"result": False}
with pytest.raises(CredentialPolicyViolationError, match="Credentials not available"):
PluginManagerService.check_credential_policy_compliance(self._request())
def test_raises_violation_on_invalid_response_format(self):
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
req.send_request.return_value = "not-a-dict"
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
PluginManagerService.check_credential_policy_compliance(self._request())
def test_raises_violation_on_api_exception(self):
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
req.send_request.side_effect = ConnectionError("network fail")
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
PluginManagerService.check_credential_policy_compliance(self._request())
def test_model_dump_serializes_credential_type_as_number(self):
body = self._request(PluginCredentialType.TOOL)
data = body.model_dump()
assert data["credential_type"] == 1
assert data["dify_credential_id"] == "cred-1"
def test_model_credential_type_values(self):
assert PluginCredentialType.MODEL.to_number() == 0
assert PluginCredentialType.TOOL.to_number() == 1

View File

@@ -0,0 +1,183 @@
from unittest.mock import MagicMock, patch
from models.account import TenantPluginAutoUpgradeStrategy
MODULE = "services.plugin.plugin_auto_upgrade_service"
def _patched_session():
"""Patch Session(db.engine) to return a mock session as context manager."""
session = MagicMock()
session_cls = MagicMock()
session_cls.return_value.__enter__ = MagicMock(return_value=session)
session_cls.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.Session", session_cls)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
class TestGetStrategy:
def test_returns_strategy_when_found(self):
p1, p2, session = _patched_session()
strategy = MagicMock()
session.query.return_value.where.return_value.first.return_value = strategy
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.get_strategy("t1")
assert result is strategy
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.get_strategy("t1")
assert result is None
class TestChangeStrategy:
def test_creates_new_strategy(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.return_value = MagicMock()
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.change_strategy(
"t1",
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
3,
TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
[],
[],
)
assert result is True
session.add.assert_called_once()
session.commit.assert_called_once()
def test_updates_existing_strategy(self):
p1, p2, session = _patched_session()
existing = MagicMock()
session.query.return_value.where.return_value.first.return_value = existing
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.change_strategy(
"t1",
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
5,
TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
["p1"],
["p2"],
)
assert result is True
assert existing.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
assert existing.upgrade_time_of_day == 5
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
assert existing.exclude_plugins == ["p1"]
assert existing.include_plugins == ["p2"]
session.commit.assert_called_once()
class TestExcludePlugin:
def test_creates_default_strategy_when_none_exists(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
with (
p1,
p2,
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
):
strat_cls.StrategySetting.FIX_ONLY = "fix_only"
strat_cls.UpgradeMode.EXCLUDE = "exclude"
cs.return_value = True
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.exclude_plugin("t1", "plugin-1")
assert result is True
cs.assert_called_once()
def test_appends_to_exclude_list_in_exclude_mode(self):
p1, p2, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p-existing"]
session.query.return_value.where.return_value.first.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.exclude_plugin("t1", "p-new")
assert result is True
assert existing.exclude_plugins == ["p-existing", "p-new"]
session.commit.assert_called_once()
def test_removes_from_include_list_in_partial_mode(self):
p1, p2, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "partial"
existing.include_plugins = ["p1", "p2"]
session.query.return_value.where.return_value.first.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
assert result is True
assert existing.include_plugins == ["p2"]
def test_switches_to_exclude_mode_from_all(self):
p1, p2, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "all"
session.query.return_value.where.return_value.first.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
assert result is True
assert existing.upgrade_mode == "exclude"
assert existing.exclude_plugins == ["p1"]
def test_no_duplicate_in_exclude_list(self):
p1, p2, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p1"]
session.query.return_value.where.return_value.first.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
PluginAutoUpgradeService.exclude_plugin("t1", "p1")
assert existing.exclude_plugins == ["p1"]

View File

@@ -0,0 +1,75 @@
from unittest.mock import MagicMock, patch
from models.account import TenantPluginPermission
MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch Session(db.engine) to return a mock session as context manager."""
session = MagicMock()
session_cls = MagicMock()
session_cls.return_value.__enter__ = MagicMock(return_value=session)
session_cls.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.Session", session_cls)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
class TestGetPermission:
def test_returns_permission_when_found(self):
p1, p2, session = _patched_session()
permission = MagicMock()
session.query.return_value.where.return_value.first.return_value = permission
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
assert result is permission
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
assert result is None
class TestChangePermission:
def test_creates_new_permission_when_not_exists(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
)
session.add.assert_called_once()
session.commit.assert_called_once()
def test_updates_existing_permission(self):
p1, p2, session = _patched_session()
existing = MagicMock()
session.query.return_value.where.return_value.first.return_value = existing
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
)
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.commit.assert_called_once()
session.add.assert_not_called()

View File

@@ -0,0 +1,135 @@
import datetime
from unittest.mock import MagicMock, patch
from services.retention.conversation.messages_clean_policy import (
BillingDisabledPolicy,
BillingSandboxPolicy,
SimpleMessage,
create_message_clean_policy,
)
MODULE = "services.retention.conversation.messages_clean_policy"
def _msg(msg_id: str, app_id: str, days_ago: int = 0) -> SimpleMessage:
return SimpleMessage(
id=msg_id,
app_id=app_id,
created_at=datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days_ago),
)
class TestBillingDisabledPolicy:
def test_returns_all_message_ids(self):
policy = BillingDisabledPolicy()
msgs = [_msg("m1", "app1"), _msg("m2", "app2"), _msg("m3", "app1")]
result = policy.filter_message_ids(msgs, {"app1": "t1", "app2": "t2"})
assert set(result) == {"m1", "m2", "m3"}
def test_empty_messages_returns_empty(self):
assert BillingDisabledPolicy().filter_message_ids([], {}) == []
class TestBillingSandboxPolicy:
def _policy(self, plans, *, graceful_days=21, whitelist=None, now=1_000_000_000):
return BillingSandboxPolicy(
plan_provider=lambda _ids: plans,
graceful_period_days=graceful_days,
tenant_whitelist=whitelist,
current_timestamp=now,
)
def test_empty_messages_returns_empty(self):
policy = self._policy({})
assert policy.filter_message_ids([], {"app1": "t1"}) == []
def test_empty_app_to_tenant_returns_empty(self):
policy = self._policy({})
assert policy.filter_message_ids([_msg("m1", "app1")], {}) == []
def test_empty_plans_returns_empty(self):
policy = self._policy({})
msgs = [_msg("m1", "app1")]
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
def test_non_sandbox_tenant_skipped(self):
plans = {"t1": {"plan": "professional", "expiration_date": 0}}
policy = self._policy(plans)
msgs = [_msg("m1", "app1")]
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
def test_sandbox_no_previous_subscription_deletes(self):
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
policy = self._policy(plans)
msgs = [_msg("m1", "app1")]
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
def test_sandbox_expired_beyond_grace_period_deletes(self):
now = 1_000_000_000
expired_long_ago = now - (22 * 24 * 60 * 60) # 22 days ago > 21 day grace
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_long_ago}}
policy = self._policy(plans, now=now)
msgs = [_msg("m1", "app1")]
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
def test_sandbox_within_grace_period_kept(self):
now = 1_000_000_000
expired_recently = now - (10 * 24 * 60 * 60) # 10 days ago < 21 day grace
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_recently}}
policy = self._policy(plans, now=now)
msgs = [_msg("m1", "app1")]
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
def test_whitelisted_tenant_skipped(self):
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
policy = self._policy(plans, whitelist=["t1"])
msgs = [_msg("m1", "app1")]
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
def test_message_without_tenant_mapping_skipped(self):
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
policy = self._policy(plans)
msgs = [_msg("m1", "unmapped_app")]
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
def test_mixed_tenants_only_sandbox_deleted(self):
plans = {
"t_sandbox": {"plan": "sandbox", "expiration_date": -1},
"t_pro": {"plan": "professional", "expiration_date": 0},
}
policy = self._policy(plans)
msgs = [_msg("m1", "app_sandbox"), _msg("m2", "app_pro")]
app_map = {"app_sandbox": "t_sandbox", "app_pro": "t_pro"}
result = policy.filter_message_ids(msgs, app_map)
assert result == ["m1"]
class TestCreateMessageCleanPolicy:
def test_billing_disabled_returns_disabled_policy(self):
with patch(f"{MODULE}.dify_config") as cfg:
cfg.BILLING_ENABLED = False
policy = create_message_clean_policy()
assert isinstance(policy, BillingDisabledPolicy)
def test_billing_enabled_returns_sandbox_policy(self):
with (
patch(f"{MODULE}.dify_config") as cfg,
patch(f"{MODULE}.BillingService") as bs,
):
cfg.BILLING_ENABLED = True
bs.get_expired_subscription_cleanup_whitelist.return_value = ["wl1"]
bs.get_plan_bulk_with_cache = MagicMock()
policy = create_message_clean_policy(graceful_period_days=30)
assert isinstance(policy, BillingSandboxPolicy)

View File

@@ -0,0 +1,598 @@
from unittest.mock import MagicMock, Mock, patch
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderType
from services.tools.tools_transform_service import ToolTransformService
MODULE = "services.tools.tools_transform_service"
class TestToolTransformService:
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
def test_convert_tool_with_parameter_override(self):
"""Test that runtime parameters correctly override base parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
# Create mock runtime parameters that override base parameters
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1" # Different label to verify override
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.author == "test_author"
assert result.name == "test_tool"
assert result.parameters is not None
assert len(result.parameters) == 2
# Find the overridden parameter
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
# Find the non-overridden parameter
original_param = next((p for p in result.parameters if p.name == "param2"), None)
assert original_param is not None
assert original_param.label == "Base Param 2" # Should be base version
def test_convert_tool_with_additional_runtime_parameters(self):
"""Test that additional runtime parameters are added to the final list"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters - one that overrides and one that's new
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "runtime_only"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Only Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 2
# Check that both parameters are present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "runtime_only" in param_names
# Verify the overridden parameter has runtime version
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1"
# Verify the new runtime parameter is included
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
assert new_param is not None
assert new_param.label == "Runtime Only Param"
def test_convert_tool_with_non_form_runtime_parameters(self):
"""Test that non-FORM runtime parameters are not added as new parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters with different forms
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "llm_param"
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
runtime_param2.type = "string"
runtime_param2.label = "LLM Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 1 # Only the FORM parameter should be present
# Check that only the FORM parameter is present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "llm_param" not in param_names
def test_convert_tool_with_empty_parameters(self):
"""Test conversion with empty base and runtime parameters"""
# Create mock tool with no parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = []
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_with_none_parameters(self):
"""Test conversion when base parameters is None"""
# Create mock tool with None parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = None
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_parameter_order_preserved(self):
"""Test that parameter order is preserved correctly"""
# Create mock base parameters in specific order
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
base_param3 = Mock(spec=ToolParameter)
base_param3.name = "param3"
base_param3.form = ToolParameter.ToolParameterForm.FORM
base_param3.type = "string"
base_param3.label = "Base Param 3"
# Create runtime parameter that overrides middle parameter
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "param2"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Param 2"
# Create new runtime parameter
runtime_param4 = Mock(spec=ToolParameter)
runtime_param4.name = "param4"
runtime_param4.form = ToolParameter.ToolParameterForm.FORM
runtime_param4.type = "string"
runtime_param4.label = "Runtime Param 4"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 4
# Check that order is maintained: base parameters first, then new runtime parameters
param_names = [p.name for p in result.parameters]
assert param_names == ["param1", "param2", "param3", "param4"]
# Verify that param2 was overridden with runtime version
param2 = result.parameters[1]
assert param2.name == "param2"
assert param2.label == "Runtime Param 2"
class TestWorkflowProviderToUserProvider:
"""Test cases for ToolTransformService.workflow_provider_to_user_provider method"""
def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
"""Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
workflow_app_id = "app_123"
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["label1", "label2"],
workflow_app_id=workflow_app_id,
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.author == "test_author"
assert result.name == "test_workflow_tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == workflow_app_id
assert result.labels == ["label1", "label2"]
assert result.is_team_authorization is True
assert result.plugin_id is None
assert result.plugin_unique_identifier is None
assert result.tools == []
def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
"""Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method without workflow_app_id
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["label1"],
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.workflow_app_id is None
assert result.labels == ["label1"]
def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
"""Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method with explicit None values
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=None,
workflow_app_id=None,
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.workflow_app_id is None
assert result.labels == []
def test_workflow_provider_to_user_provider_preserves_other_fields(self):
"""Test that workflow_provider_to_user_provider preserves all other entity fields."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller with various fields
workflow_app_id = "app_456"
provider_id = "provider_456"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "another_author"
mock_controller.entity.identity.name = "another_workflow_tool"
mock_controller.entity.identity.description = I18nObject(
en_US="Another description", zh_Hans="Another description"
)
mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.label = I18nObject(
en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
)
# Call the method
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["automation", "workflow"],
workflow_app_id=workflow_app_id,
)
# Verify all fields are preserved correctly
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.author == "another_author"
assert result.name == "another_workflow_tool"
assert result.description.en_US == "Another description"
assert result.description.zh_Hans == "Another description"
assert result.icon == {"type": "emoji", "content": "⚙️"}
assert result.icon_dark == {"type": "emoji", "content": "🔧"}
assert result.label.en_US == "Another Workflow Tool"
assert result.label.zh_Hans == "Another Workflow Tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == workflow_app_id
assert result.labels == ["automation", "workflow"]
assert result.masked_credentials == {}
assert result.is_team_authorization is True
assert result.allow_delete is True
assert result.plugin_id is None
assert result.plugin_unique_identifier is None
assert result.tools == []
class TestGetToolProviderIconUrl:
def test_builtin_provider_returns_console_url(self):
with patch(f"{MODULE}.dify_config") as cfg:
cfg.CONSOLE_API_URL = "https://app.dify.ai"
url = ToolTransformService.get_tool_provider_icon_url("builtin", "google", "icon.png")
assert "/builtin/google/icon" in url
assert url.startswith("https://app.dify.ai/console/api/workspaces/current/tool-provider")
def test_builtin_provider_with_no_console_url(self):
with patch(f"{MODULE}.dify_config") as cfg:
cfg.CONSOLE_API_URL = None
url = ToolTransformService.get_tool_provider_icon_url("builtin", "slack", "icon.png")
assert "/builtin/slack/icon" in url
def test_api_provider_parses_json_icon(self):
icon_json = '{"background": "#fff", "content": "A"}'
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon_json)
assert result == {"background": "#fff", "content": "A"}
def test_api_provider_returns_dict_icon_directly(self):
icon = {"background": "#000", "content": "B"}
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon)
assert result == icon
def test_api_provider_returns_fallback_on_invalid_json(self):
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", "not-json")
assert result == {"background": "#252525", "content": "\ud83d\ude01"}
def test_workflow_provider_behaves_like_api(self):
icon = {"background": "#123", "content": "W"}
assert ToolTransformService.get_tool_provider_icon_url("workflow", "wf", icon) == icon
def test_mcp_returns_icon_as_is(self):
assert ToolTransformService.get_tool_provider_icon_url("mcp", "srv", "icon-value") == "icon-value"
def test_unknown_type_returns_empty(self):
assert ToolTransformService.get_tool_provider_icon_url("unknown", "x", "i") == ""
class TestRepackProvider:
def test_repacks_dict_provider_icon(self):
provider = {"type": "builtin", "name": "google", "icon": "old"}
with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/new-url") as mock_fn:
ToolTransformService.repack_provider("t1", provider)
assert provider["icon"] == "/new-url"
mock_fn.assert_called_once_with(provider_type="builtin", provider_name="google", icon="old")
def test_repacks_tool_provider_api_entity_without_plugin(self):
entity = MagicMock(spec=ToolProviderApiEntity)
entity.plugin_id = None
entity.type = ToolProviderType.BUILT_IN
entity.name = "slack"
entity.icon = "icon.svg"
entity.icon_dark = "dark.svg"
with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/url"):
ToolTransformService.repack_provider("t1", entity)
assert entity.icon == "/url"
assert entity.icon_dark == "/url"
class TestConvertMcpSchemaToParameter:
def test_simple_object_schema(self):
schema = {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
"count": {"type": "integer", "description": "Result count"},
},
"required": ["query"],
}
params = ToolTransformService.convert_mcp_schema_to_parameter(schema)
assert len(params) == 2
query_param = next(p for p in params if p.name == "query")
count_param = next(p for p in params if p.name == "count")
assert query_param.required is True
assert count_param.required is False
assert count_param.type.value == "number"
def test_float_maps_to_number(self):
schema = {"type": "object", "properties": {"rate": {"type": "float"}}, "required": []}
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "number"
def test_array_type_attaches_input_schema(self):
prop = {"type": "array", "description": "Items", "items": {"type": "string"}}
schema = {"type": "object", "properties": {"items": prop}, "required": []}
param = ToolTransformService.convert_mcp_schema_to_parameter(schema)[0]
assert param.input_schema is not None
def test_non_object_schema_returns_empty(self):
assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "string"}) == []
def test_missing_properties_returns_empty(self):
assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "object"}) == []
def test_list_type_uses_first_element(self):
schema = {"type": "object", "properties": {"f": {"type": ["string", "null"]}}, "required": []}
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "string"
def test_missing_description_defaults_empty(self):
schema = {"type": "object", "properties": {"f": {"type": "string"}}, "required": []}
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].llm_description == ""
class TestApiProviderToController:
def test_api_key_header_auth(self):
db_provider = MagicMock()
db_provider.credentials = {"auth_type": "api_key_header"}
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
ctrl_cls.from_db.return_value = MagicMock()
ToolTransformService.api_provider_to_controller(db_provider)
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER)
def test_api_key_query_auth(self):
db_provider = MagicMock()
db_provider.credentials = {"auth_type": "api_key_query"}
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
ctrl_cls.from_db.return_value = MagicMock()
ToolTransformService.api_provider_to_controller(db_provider)
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_QUERY)
def test_legacy_api_key_maps_to_header(self):
db_provider = MagicMock()
db_provider.credentials = {"auth_type": "api_key"}
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
ctrl_cls.from_db.return_value = MagicMock()
ToolTransformService.api_provider_to_controller(db_provider)
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER)
def test_unknown_auth_defaults_to_none(self):
db_provider = MagicMock()
db_provider.credentials = {"auth_type": "something_else"}
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
ctrl_cls.from_db.return_value = MagicMock()
ToolTransformService.api_provider_to_controller(db_provider)
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.NONE)