From 2de818530bf384dff43d473456a2e09ea60190f1 Mon Sep 17 00:00:00 2001 From: Dev Sharma <50591491+cryptus-neoxys@users.noreply.github.com> Date: Tue, 31 Mar 2026 08:46:42 +0530 Subject: [PATCH 1/7] test: add tests for api/services retention, enterprise, plugin (#32648) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: QuantumGhost --- .../services/plugin/__init__.py | 0 .../services/plugin/test_plugin_lifecycle.py | 182 ++++++ .../services/retention/__init__.py | 0 .../retention/test_messages_clean_service.py | 348 ++++++++++ .../retention/test_workflow_run_archiver.py | 177 ++++++ .../enterprise/test_enterprise_service.py | 170 +++++ .../enterprise/test_plugin_manager_service.py | 49 ++ .../test_plugin_auto_upgrade_service.py | 183 ++++++ .../plugin/test_plugin_permission_service.py | 75 +++ .../unit_tests/services/retention/__init__.py | 0 .../retention/test_messages_clean_policy.py | 135 ++++ .../tools/test_tools_transform_service.py | 598 ++++++++++++++++++ 12 files changed, 1917 insertions(+) create mode 100644 api/tests/integration_tests/services/plugin/__init__.py create mode 100644 api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py create mode 100644 api/tests/integration_tests/services/retention/__init__.py create mode 100644 api/tests/integration_tests/services/retention/test_messages_clean_service.py create mode 100644 api/tests/integration_tests/services/retention/test_workflow_run_archiver.py create mode 100644 api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py create mode 100644 api/tests/unit_tests/services/plugin/test_plugin_permission_service.py create mode 100644 api/tests/unit_tests/services/retention/__init__.py create mode 100644 api/tests/unit_tests/services/retention/test_messages_clean_policy.py create mode 100644 api/tests/unit_tests/services/tools/test_tools_transform_service.py diff --git a/api/tests/integration_tests/services/plugin/__init__.py b/api/tests/integration_tests/services/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py new file mode 100644 index 0000000000..951a5ab4b4 --- /dev/null +++ b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py @@ -0,0 +1,182 @@ +import pytest +from sqlalchemy import delete + +from core.db.session_factory import session_factory +from models import Tenant +from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission +from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService +from services.plugin.plugin_permission_service import PluginPermissionService + + +@pytest.fixture +def tenant(flask_req_ctx): + with session_factory.create_session() as session: + t = Tenant(name="plugin_it_tenant") + session.add(t) + session.commit() + tenant_id = t.id + + yield tenant_id + + with session_factory.create_session() as session: + session.execute(delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)) + session.execute( + delete(TenantPluginAutoUpgradeStrategy).where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) + ) + session.execute(delete(Tenant).where(Tenant.id == tenant_id)) + session.commit() + + +class TestPluginPermissionLifecycle: + def test_get_returns_none_for_new_tenant(self, tenant): + assert PluginPermissionService.get_permission(tenant) is None + + def test_change_creates_row(self, tenant): + result = PluginPermissionService.change_permission( + tenant, + TenantPluginPermission.InstallPermission.ADMINS, + TenantPluginPermission.DebugPermission.EVERYONE, + ) + assert result is True + + perm = PluginPermissionService.get_permission(tenant) + assert perm is not None + assert perm.install_permission == TenantPluginPermission.InstallPermission.ADMINS + assert perm.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE + + def test_change_updates_existing_row(self, tenant): + PluginPermissionService.change_permission( + tenant, + TenantPluginPermission.InstallPermission.ADMINS, + TenantPluginPermission.DebugPermission.NOBODY, + ) + PluginPermissionService.change_permission( + tenant, + TenantPluginPermission.InstallPermission.EVERYONE, + TenantPluginPermission.DebugPermission.ADMINS, + ) + perm = PluginPermissionService.get_permission(tenant) + assert perm is not None + assert perm.install_permission == TenantPluginPermission.InstallPermission.EVERYONE + assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS + + with session_factory.create_session() as session: + count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count() + assert count == 1 + + +class TestPluginAutoUpgradeLifecycle: + def test_get_returns_none_for_new_tenant(self, tenant): + assert PluginAutoUpgradeService.get_strategy(tenant) is None + + def test_change_creates_row(self, tenant): + result = PluginAutoUpgradeService.change_strategy( + tenant, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + upgrade_time_of_day=3, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + exclude_plugins=[], + include_plugins=[], + ) + assert result is True + + strategy = PluginAutoUpgradeService.get_strategy(tenant) + assert strategy is not None + assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST + assert strategy.upgrade_time_of_day == 3 + + def test_change_updates_existing_row(self, tenant): + PluginAutoUpgradeService.change_strategy( + tenant, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=0, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + exclude_plugins=[], + include_plugins=[], + ) + PluginAutoUpgradeService.change_strategy( + tenant, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + upgrade_time_of_day=12, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL, + exclude_plugins=[], + include_plugins=["plugin-a"], + ) + + strategy = PluginAutoUpgradeService.get_strategy(tenant) + assert strategy is not None + assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST + assert strategy.upgrade_time_of_day == 12 + assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL + assert strategy.include_plugins == ["plugin-a"] + + def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant): + PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin") + + strategy = PluginAutoUpgradeService.get_strategy(tenant) + assert strategy is not None + assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE + assert "my-plugin" in strategy.exclude_plugins + + def test_exclude_plugin_appends_in_exclude_mode(self, tenant): + PluginAutoUpgradeService.change_strategy( + tenant, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=0, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + exclude_plugins=["existing"], + include_plugins=[], + ) + PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin") + + strategy = PluginAutoUpgradeService.get_strategy(tenant) + assert strategy is not None + assert "existing" in strategy.exclude_plugins + assert "new-plugin" in strategy.exclude_plugins + + def test_exclude_plugin_dedup_in_exclude_mode(self, tenant): + PluginAutoUpgradeService.change_strategy( + tenant, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=0, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + exclude_plugins=["same-plugin"], + include_plugins=[], + ) + PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin") + + strategy = PluginAutoUpgradeService.get_strategy(tenant) + assert strategy is not None + assert strategy.exclude_plugins.count("same-plugin") == 1 + + def test_exclude_from_partial_mode_removes_from_include(self, tenant): + PluginAutoUpgradeService.change_strategy( + tenant, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=0, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL, + exclude_plugins=[], + include_plugins=["p1", "p2"], + ) + PluginAutoUpgradeService.exclude_plugin(tenant, "p1") + + strategy = PluginAutoUpgradeService.get_strategy(tenant) + assert strategy is not None + assert "p1" not in strategy.include_plugins + assert "p2" in strategy.include_plugins + + def test_exclude_from_all_mode_switches_to_exclude(self, tenant): + PluginAutoUpgradeService.change_strategy( + tenant, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + upgrade_time_of_day=0, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + exclude_plugins=[], + include_plugins=[], + ) + PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin") + + strategy = PluginAutoUpgradeService.get_strategy(tenant) + assert strategy is not None + assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE + assert "excluded-plugin" in strategy.exclude_plugins diff --git a/api/tests/integration_tests/services/retention/__init__.py b/api/tests/integration_tests/services/retention/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/services/retention/test_messages_clean_service.py b/api/tests/integration_tests/services/retention/test_messages_clean_service.py new file mode 100644 index 0000000000..348bb0af4a --- /dev/null +++ b/api/tests/integration_tests/services/retention/test_messages_clean_service.py @@ -0,0 +1,348 @@ +import datetime +import math +import uuid + +import pytest +from sqlalchemy import delete + +from core.db.session_factory import session_factory +from models import Tenant +from models.enums import FeedbackFromSource, FeedbackRating +from models.model import ( + App, + Conversation, + Message, + MessageAnnotation, + MessageFeedback, +) +from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy +from services.retention.conversation.messages_clean_service import MessagesCleanService + +_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC) +_OLD = _NOW - datetime.timedelta(days=60) +_VERY_OLD = _NOW - datetime.timedelta(days=90) +_RECENT = _NOW - datetime.timedelta(days=5) + +_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1) +_WINDOW_END = _RECENT + datetime.timedelta(hours=1) + +_DEFAULT_BATCH_SIZE = 100 +_PAGINATION_MESSAGE_COUNT = 25 +_PAGINATION_BATCH_SIZE = 8 + + +@pytest.fixture +def tenant_and_app(flask_req_ctx): + """Creates a Tenant, App and Conversation for the test and cleans up after.""" + with session_factory.create_session() as session: + tenant = Tenant(name="retention_it_tenant") + session.add(tenant) + session.flush() + + app = App( + tenant_id=tenant.id, + name="Retention IT App", + mode="chat", + enable_site=True, + enable_api=True, + ) + session.add(app) + session.flush() + + conv = Conversation( + app_id=app.id, + mode="chat", + name="test_conv", + status="normal", + from_source="console", + _inputs={}, + ) + session.add(conv) + session.commit() + + tenant_id = tenant.id + app_id = app.id + conv_id = conv.id + + yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id} + + with session_factory.create_session() as session: + session.execute(delete(Conversation).where(Conversation.id == conv_id)) + session.execute(delete(App).where(App.id == app_id)) + session.execute(delete(Tenant).where(Tenant.id == tenant_id)) + session.commit() + + +def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message: + return Message( + app_id=app_id, + conversation_id=conversation_id, + query="test", + message=[{"text": "hello"}], + answer="world", + message_tokens=1, + message_unit_price=0, + answer_tokens=1, + answer_unit_price=0, + from_source="console", + currency="USD", + _inputs={}, + created_at=created_at, + ) + + +class TestMessagesCleanServiceIntegration: + @pytest.fixture + def seed_messages(self, tenant_and_app): + """Seeds one message at each of _VERY_OLD, _OLD, and _RECENT. + Yields a semantic mapping keyed by age label. + """ + data = tenant_and_app + app_id = data["app_id"] + conv_id = data["conversation_id"] + # Ordered tuple of (label, timestamp) for deterministic seeding + timestamps = [ + ("very_old", _VERY_OLD), + ("old", _OLD), + ("recent", _RECENT), + ] + msg_ids: dict[str, str] = {} + + with session_factory.create_session() as session: + for label, ts in timestamps: + msg = _make_message(app_id, conv_id, ts) + session.add(msg) + session.flush() + msg_ids[label] = msg.id + session.commit() + + yield {"msg_ids": msg_ids, **data} + + with session_factory.create_session() as session: + session.execute( + delete(Message) + .where(Message.id.in_(list(msg_ids.values()))) + .execution_options(synchronize_session=False) + ) + session.commit() + + @pytest.fixture + def paginated_seed_messages(self, tenant_and_app): + """Seeds multiple messages separated by 1-second increments starting at _OLD.""" + data = tenant_and_app + app_id = data["app_id"] + conv_id = data["conversation_id"] + msg_ids: list[str] = [] + + with session_factory.create_session() as session: + for i in range(_PAGINATION_MESSAGE_COUNT): + ts = _OLD + datetime.timedelta(seconds=i) + msg = _make_message(app_id, conv_id, ts) + session.add(msg) + session.flush() + msg_ids.append(msg.id) + session.commit() + + yield {"msg_ids": msg_ids, **data} + + with session_factory.create_session() as session: + session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False)) + session.commit() + + @pytest.fixture + def cascade_test_data(self, tenant_and_app): + """Seeds one Message with an associated Feedback and Annotation.""" + data = tenant_and_app + app_id = data["app_id"] + conv_id = data["conversation_id"] + + with session_factory.create_session() as session: + msg = _make_message(app_id, conv_id, _OLD) + session.add(msg) + session.flush() + + feedback = MessageFeedback( + app_id=app_id, + conversation_id=conv_id, + message_id=msg.id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, + ) + annotation = MessageAnnotation( + app_id=app_id, + conversation_id=conv_id, + message_id=msg.id, + question="q", + content="a", + account_id=str(uuid.uuid4()), + ) + session.add_all([feedback, annotation]) + session.commit() + + msg_id = msg.id + fb_id = feedback.id + ann_id = annotation.id + + yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data} + + with session_factory.create_session() as session: + session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id)) + session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id)) + session.execute(delete(Message).where(Message.id == msg_id)) + session.commit() + + def test_dry_run_does_not_delete(self, seed_messages): + """Dry-run must count eligible rows without deleting any of them.""" + data = seed_messages + msg_ids = data["msg_ids"] + all_ids = list(msg_ids.values()) + + svc = MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=_WINDOW_START, + end_before=_WINDOW_END, + batch_size=_DEFAULT_BATCH_SIZE, + dry_run=True, + ) + stats = svc.run() + + assert stats["filtered_messages"] == len(all_ids) + assert stats["total_deleted"] == 0 + + with session_factory.create_session() as session: + remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + assert remaining == len(all_ids) + + def test_billing_disabled_deletes_all_in_range(self, seed_messages): + """All 3 seeded messages fall within the window and must be deleted.""" + data = seed_messages + msg_ids = data["msg_ids"] + all_ids = list(msg_ids.values()) + + svc = MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=_WINDOW_START, + end_before=_WINDOW_END, + batch_size=_DEFAULT_BATCH_SIZE, + dry_run=False, + ) + stats = svc.run() + + assert stats["total_deleted"] == len(all_ids) + + with session_factory.create_session() as session: + remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + assert remaining == 0 + + def test_start_from_filters_correctly(self, seed_messages): + """Only the message at _OLD falls within the narrow ±1 h window.""" + data = seed_messages + msg_ids = data["msg_ids"] + + start = _OLD - datetime.timedelta(hours=1) + end = _OLD + datetime.timedelta(hours=1) + + svc = MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=start, + end_before=end, + batch_size=_DEFAULT_BATCH_SIZE, + ) + stats = svc.run() + + assert stats["total_deleted"] == 1 + + with session_factory.create_session() as session: + all_ids = list(msg_ids.values()) + remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()} + + assert msg_ids["old"] not in remaining_ids + assert msg_ids["very_old"] in remaining_ids + assert msg_ids["recent"] in remaining_ids + + def test_cursor_pagination_across_batches(self, paginated_seed_messages): + """Messages must be deleted across multiple batches.""" + data = paginated_seed_messages + msg_ids = data["msg_ids"] + + # _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s. + pagination_window_start = _OLD - datetime.timedelta(seconds=1) + pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT) + + svc = MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=pagination_window_start, + end_before=pagination_window_end, + batch_size=_PAGINATION_BATCH_SIZE, + ) + stats = svc.run() + + assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT + expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE) + assert stats["batches"] >= expected_batches + + with session_factory.create_session() as session: + remaining = session.query(Message).where(Message.id.in_(msg_ids)).count() + assert remaining == 0 + + def test_no_messages_in_range_returns_empty_stats(self, seed_messages): + """A window entirely in the future must yield zero matches.""" + far_future = _NOW + datetime.timedelta(days=365) + even_further = far_future + datetime.timedelta(days=1) + + svc = MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=far_future, + end_before=even_further, + batch_size=_DEFAULT_BATCH_SIZE, + ) + stats = svc.run() + + assert stats["total_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_relation_cascade_deletes(self, cascade_test_data): + """Deleting a Message must cascade to its Feedback and Annotation rows.""" + data = cascade_test_data + msg_id = data["msg_id"] + fb_id = data["fb_id"] + ann_id = data["ann_id"] + + svc = MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=_OLD - datetime.timedelta(hours=1), + end_before=_OLD + datetime.timedelta(hours=1), + batch_size=_DEFAULT_BATCH_SIZE, + ) + stats = svc.run() + + assert stats["total_deleted"] == 1 + + with session_factory.create_session() as session: + assert session.query(Message).where(Message.id == msg_id).count() == 0 + assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0 + assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0 + + def test_factory_from_time_range_validation(self): + with pytest.raises(ValueError, match="start_from"): + MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=_NOW, + end_before=_OLD, + ) + + def test_factory_from_days_validation(self): + with pytest.raises(ValueError, match="days"): + MessagesCleanService.from_days( + policy=BillingDisabledPolicy(), + days=-1, + ) + + def test_factory_batch_size_validation(self): + with pytest.raises(ValueError, match="batch_size"): + MessagesCleanService.from_time_range( + policy=BillingDisabledPolicy(), + start_from=_OLD, + end_before=_NOW, + batch_size=0, + ) diff --git a/api/tests/integration_tests/services/retention/test_workflow_run_archiver.py b/api/tests/integration_tests/services/retention/test_workflow_run_archiver.py new file mode 100644 index 0000000000..5728eacdfb --- /dev/null +++ b/api/tests/integration_tests/services/retention/test_workflow_run_archiver.py @@ -0,0 +1,177 @@ +import datetime +import io +import json +import uuid +import zipfile +from unittest.mock import MagicMock, patch + +import pytest + +from services.retention.workflow_run.archive_paid_plan_workflow_run import ( + ArchiveSummary, + WorkflowRunArchiver, +) +from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION + + +class TestWorkflowRunArchiverInit: + def test_start_from_without_end_before_raises(self): + with pytest.raises(ValueError, match="start_from and end_before must be provided together"): + WorkflowRunArchiver(start_from=datetime.datetime(2025, 1, 1)) + + def test_end_before_without_start_from_raises(self): + with pytest.raises(ValueError, match="start_from and end_before must be provided together"): + WorkflowRunArchiver(end_before=datetime.datetime(2025, 1, 1)) + + def test_start_equals_end_raises(self): + ts = datetime.datetime(2025, 1, 1) + with pytest.raises(ValueError, match="start_from must be earlier than end_before"): + WorkflowRunArchiver(start_from=ts, end_before=ts) + + def test_start_after_end_raises(self): + with pytest.raises(ValueError, match="start_from must be earlier than end_before"): + WorkflowRunArchiver( + start_from=datetime.datetime(2025, 6, 1), + end_before=datetime.datetime(2025, 1, 1), + ) + + def test_workers_zero_raises(self): + with pytest.raises(ValueError, match="workers must be at least 1"): + WorkflowRunArchiver(workers=0) + + def test_valid_init_defaults(self): + archiver = WorkflowRunArchiver(days=30, batch_size=50) + assert archiver.days == 30 + assert archiver.batch_size == 50 + assert archiver.dry_run is False + assert archiver.delete_after_archive is False + assert archiver.start_from is None + + def test_valid_init_with_time_range(self): + start = datetime.datetime(2025, 1, 1) + end = datetime.datetime(2025, 6, 1) + archiver = WorkflowRunArchiver(start_from=start, end_before=end, workers=2) + assert archiver.start_from is not None + assert archiver.end_before is not None + assert archiver.workers == 2 + + +class TestBuildArchiveBundle: + def test_bundle_contains_manifest_and_all_tables(self): + archiver = WorkflowRunArchiver(days=90) + + manifest_data = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8") + table_payloads = dict.fromkeys(archiver.ARCHIVED_TABLES, b"") + + bundle_bytes = archiver._build_archive_bundle(manifest_data, table_payloads) + + with zipfile.ZipFile(io.BytesIO(bundle_bytes), "r") as zf: + names = set(zf.namelist()) + assert "manifest.json" in names + for table in archiver.ARCHIVED_TABLES: + assert f"{table}.jsonl" in names, f"Missing {table}.jsonl in bundle" + + def test_bundle_missing_table_payload_raises(self): + archiver = WorkflowRunArchiver(days=90) + manifest_data = b"{}" + incomplete_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"} + + with pytest.raises(ValueError, match="Missing archive payload"): + archiver._build_archive_bundle(manifest_data, incomplete_payloads) + + +class TestGenerateManifest: + def test_manifest_structure(self): + archiver = WorkflowRunArchiver(days=90) + from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats + + run = MagicMock() + run.id = str(uuid.uuid4()) + run.tenant_id = str(uuid.uuid4()) + run.app_id = str(uuid.uuid4()) + run.workflow_id = str(uuid.uuid4()) + run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0) + + stats = [ + TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512), + TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024), + ] + + manifest = archiver._generate_manifest(run, stats) + + assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION + assert manifest["workflow_run_id"] == run.id + assert manifest["tenant_id"] == run.tenant_id + assert manifest["app_id"] == run.app_id + assert "tables" in manifest + assert manifest["tables"]["workflow_runs"]["row_count"] == 1 + assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123" + assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2 + + +class TestFilterPaidTenants: + def test_all_tenants_paid_when_billing_disabled(self): + archiver = WorkflowRunArchiver(days=90) + tenant_ids = {"t1", "t2", "t3"} + + with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg: + cfg.BILLING_ENABLED = False + result = archiver._filter_paid_tenants(tenant_ids) + + assert result == tenant_ids + + def test_empty_tenants_returns_empty(self): + archiver = WorkflowRunArchiver(days=90) + + with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg: + cfg.BILLING_ENABLED = True + result = archiver._filter_paid_tenants(set()) + + assert result == set() + + def test_only_paid_plans_returned(self): + archiver = WorkflowRunArchiver(days=90) + + mock_bulk = { + "t1": {"plan": "professional"}, + "t2": {"plan": "sandbox"}, + "t3": {"plan": "team"}, + } + + with ( + patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg, + patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing, + ): + cfg.BILLING_ENABLED = True + billing.get_plan_bulk_with_cache.return_value = mock_bulk + result = archiver._filter_paid_tenants({"t1", "t2", "t3"}) + + assert "t1" in result + assert "t3" in result + assert "t2" not in result + + def test_billing_api_failure_returns_empty(self): + archiver = WorkflowRunArchiver(days=90) + + with ( + patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg, + patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing, + ): + cfg.BILLING_ENABLED = True + billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down") + result = archiver._filter_paid_tenants({"t1"}) + + assert result == set() + + +class TestDryRunArchive: + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage") + def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx): + archiver = WorkflowRunArchiver(days=90, dry_run=True) + + with patch.object(archiver, "_get_runs_batch", return_value=[]): + summary = archiver.run() + + mock_get_storage.assert_not_called() + assert isinstance(summary, ArchiveSummary) + assert summary.runs_failed == 0 diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py index 59c07bfb37..6ad6a490b0 100644 --- a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -5,6 +5,7 @@ Covers: - License status caching (get_cached_license_status) """ +from datetime import datetime from unittest.mock import patch import pytest @@ -15,9 +16,178 @@ from services.enterprise.enterprise_service import ( VALID_LICENSE_CACHE_TTL, DefaultWorkspaceJoinResult, EnterpriseService, + WebAppSettings, + WorkspacePermission, try_join_default_workspace, ) +MODULE = "services.enterprise.enterprise_service" + + +class TestEnterpriseServiceInfo: + def test_get_info_delegates(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"version": "1.0"} + result = EnterpriseService.get_info() + + req.send_request.assert_called_once_with("GET", "/info") + assert result == {"version": "1.0"} + + def test_get_workspace_info_delegates(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"name": "ws"} + result = EnterpriseService.get_workspace_info("tenant-1") + + req.send_request.assert_called_once_with("GET", "/workspace/tenant-1/info") + assert result == {"name": "ws"} + + +class TestSsoSettingsLastUpdateTime: + def test_app_sso_parses_valid_timestamp(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = "2025-01-15T10:30:00+00:00" + result = EnterpriseService.get_app_sso_settings_last_update_time() + + assert isinstance(result, datetime) + assert result.year == 2025 + + def test_app_sso_raises_on_empty(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = "" + with pytest.raises(ValueError, match="No data found"): + EnterpriseService.get_app_sso_settings_last_update_time() + + def test_app_sso_raises_on_invalid_format(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = "not-a-date" + with pytest.raises(ValueError, match="Invalid date format"): + EnterpriseService.get_app_sso_settings_last_update_time() + + def test_workspace_sso_parses_valid_timestamp(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = "2025-06-01T00:00:00+00:00" + result = EnterpriseService.get_workspace_sso_settings_last_update_time() + + assert isinstance(result, datetime) + + def test_workspace_sso_raises_on_empty(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = None + with pytest.raises(ValueError, match="No data found"): + EnterpriseService.get_workspace_sso_settings_last_update_time() + + +class TestWorkspacePermissionService: + def test_raises_on_empty_workspace_id(self): + with pytest.raises(ValueError, match="workspace_id must be provided"): + EnterpriseService.WorkspacePermissionService.get_permission("") + + def test_raises_on_missing_data(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = None + with pytest.raises(ValueError, match="No data found"): + EnterpriseService.WorkspacePermissionService.get_permission("ws-1") + + def test_raises_on_missing_permission_key(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"other": "data"} + with pytest.raises(ValueError, match="No data found"): + EnterpriseService.WorkspacePermissionService.get_permission("ws-1") + + def test_returns_parsed_permission(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = { + "permission": { + "workspaceId": "ws-1", + "allowMemberInvite": True, + "allowOwnerTransfer": False, + } + } + result = EnterpriseService.WorkspacePermissionService.get_permission("ws-1") + + assert isinstance(result, WorkspacePermission) + assert result.workspace_id == "ws-1" + assert result.allow_member_invite is True + assert result.allow_owner_transfer is False + + +class TestWebAppAuth: + def test_is_user_allowed_returns_result_field(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"result": True} + assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is True + + def test_is_user_allowed_defaults_false(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {} + assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is False + + def test_batch_is_user_allowed_returns_empty_for_no_apps(self): + assert EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", []) == {} + + def test_batch_is_user_allowed_raises_on_empty_response(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = None + with pytest.raises(ValueError, match="No data found"): + EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", ["a1"]) + + def test_get_app_access_mode_raises_on_empty_app_id(self): + with pytest.raises(ValueError, match="app_id must be provided"): + EnterpriseService.WebAppAuth.get_app_access_mode_by_id("") + + def test_get_app_access_mode_returns_settings(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"accessMode": "public"} + result = EnterpriseService.WebAppAuth.get_app_access_mode_by_id("a1") + + assert isinstance(result, WebAppSettings) + assert result.access_mode == "public" + + def test_batch_get_returns_empty_for_no_apps(self): + assert EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id([]) == {} + + def test_batch_get_maps_access_modes(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"accessModes": {"a1": "public", "a2": "private"}} + result = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1", "a2"]) + + assert result["a1"].access_mode == "public" + assert result["a2"].access_mode == "private" + + def test_batch_get_raises_on_invalid_format(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"accessModes": "not-a-dict"} + with pytest.raises(ValueError, match="Invalid data format"): + EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1"]) + + def test_update_access_mode_raises_on_empty_app_id(self): + with pytest.raises(ValueError, match="app_id must be provided"): + EnterpriseService.WebAppAuth.update_app_access_mode("", "public") + + def test_update_access_mode_raises_on_invalid_mode(self): + with pytest.raises(ValueError, match="access_mode must be"): + EnterpriseService.WebAppAuth.update_app_access_mode("a1", "invalid") + + def test_update_access_mode_delegates_and_returns(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + req.send_request.return_value = {"result": True} + result = EnterpriseService.WebAppAuth.update_app_access_mode("a1", "public") + + assert result is True + req.send_request.assert_called_once_with( + "POST", "/webapp/access-mode", json={"appId": "a1", "accessMode": "public"} + ) + + def test_cleanup_webapp_raises_on_empty_app_id(self): + with pytest.raises(ValueError, match="app_id must be provided"): + EnterpriseService.WebAppAuth.cleanup_webapp("") + + def test_cleanup_webapp_delegates(self): + with patch(f"{MODULE}.EnterpriseRequest") as req: + EnterpriseService.WebAppAuth.cleanup_webapp("a1") + + req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"}) + class TestJoinDefaultWorkspace: def test_join_default_workspace_success(self): diff --git a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py index 6ee328ae2c..759d907934 100644 --- a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py +++ b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py @@ -7,14 +7,20 @@ This module covers the pre-uninstall plugin hook behavior: from unittest.mock import patch +import pytest from httpx import HTTPStatusError from configs import dify_config from services.enterprise.plugin_manager_service import ( + CheckCredentialPolicyComplianceRequest, + CredentialPolicyViolationError, + PluginCredentialType, PluginManagerService, PreUninstallPluginRequest, ) +MODULE = "services.enterprise.plugin_manager_service" + class TestTryPreUninstallPlugin: def test_try_pre_uninstall_plugin_success(self): @@ -88,3 +94,46 @@ class TestTryPreUninstallPlugin: timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, ) mock_logger.exception.assert_called_once() + + +class TestCheckCredentialPolicyCompliance: + def _request(self, cred_type=PluginCredentialType.MODEL): + return CheckCredentialPolicyComplianceRequest( + dify_credential_id="cred-1", provider="openai", credential_type=cred_type + ) + + def test_passes_when_result_true(self): + with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req: + req.send_request.return_value = {"result": True} + PluginManagerService.check_credential_policy_compliance(self._request()) + + req.send_request.assert_called_once() + + def test_raises_violation_when_result_false(self): + with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req: + req.send_request.return_value = {"result": False} + with pytest.raises(CredentialPolicyViolationError, match="Credentials not available"): + PluginManagerService.check_credential_policy_compliance(self._request()) + + def test_raises_violation_on_invalid_response_format(self): + with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req: + req.send_request.return_value = "not-a-dict" + with pytest.raises(CredentialPolicyViolationError, match="error occurred"): + PluginManagerService.check_credential_policy_compliance(self._request()) + + def test_raises_violation_on_api_exception(self): + with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req: + req.send_request.side_effect = ConnectionError("network fail") + with pytest.raises(CredentialPolicyViolationError, match="error occurred"): + PluginManagerService.check_credential_policy_compliance(self._request()) + + def test_model_dump_serializes_credential_type_as_number(self): + body = self._request(PluginCredentialType.TOOL) + data = body.model_dump() + + assert data["credential_type"] == 1 + assert data["dify_credential_id"] == "cred-1" + + def test_model_credential_type_values(self): + assert PluginCredentialType.MODEL.to_number() == 0 + assert PluginCredentialType.TOOL.to_number() == 1 diff --git a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py new file mode 100644 index 0000000000..edb50d09a6 --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py @@ -0,0 +1,183 @@ +from unittest.mock import MagicMock, patch + +from models.account import TenantPluginAutoUpgradeStrategy + +MODULE = "services.plugin.plugin_auto_upgrade_service" + + +def _patched_session(): + """Patch Session(db.engine) to return a mock session as context manager.""" + session = MagicMock() + session_cls = MagicMock() + session_cls.return_value.__enter__ = MagicMock(return_value=session) + session_cls.return_value.__exit__ = MagicMock(return_value=False) + patcher = patch(f"{MODULE}.Session", session_cls) + db_patcher = patch(f"{MODULE}.db") + return patcher, db_patcher, session + + +class TestGetStrategy: + def test_returns_strategy_when_found(self): + p1, p2, session = _patched_session() + strategy = MagicMock() + session.query.return_value.where.return_value.first.return_value = strategy + + with p1, p2: + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.get_strategy("t1") + + assert result is strategy + + def test_returns_none_when_not_found(self): + p1, p2, session = _patched_session() + session.query.return_value.where.return_value.first.return_value = None + + with p1, p2: + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.get_strategy("t1") + + assert result is None + + +class TestChangeStrategy: + def test_creates_new_strategy(self): + p1, p2, session = _patched_session() + session.query.return_value.where.return_value.first.return_value = None + + with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + strat_cls.return_value = MagicMock() + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.change_strategy( + "t1", + TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + 3, + TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + [], + [], + ) + + assert result is True + session.add.assert_called_once() + session.commit.assert_called_once() + + def test_updates_existing_strategy(self): + p1, p2, session = _patched_session() + existing = MagicMock() + session.query.return_value.where.return_value.first.return_value = existing + + with p1, p2: + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.change_strategy( + "t1", + TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + 5, + TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL, + ["p1"], + ["p2"], + ) + + assert result is True + assert existing.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST + assert existing.upgrade_time_of_day == 5 + assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL + assert existing.exclude_plugins == ["p1"] + assert existing.include_plugins == ["p2"] + session.commit.assert_called_once() + + +class TestExcludePlugin: + def test_creates_default_strategy_when_none_exists(self): + p1, p2, session = _patched_session() + session.query.return_value.where.return_value.first.return_value = None + + with ( + p1, + p2, + patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls, + patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs, + ): + strat_cls.StrategySetting.FIX_ONLY = "fix_only" + strat_cls.UpgradeMode.EXCLUDE = "exclude" + cs.return_value = True + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.exclude_plugin("t1", "plugin-1") + + assert result is True + cs.assert_called_once() + + def test_appends_to_exclude_list_in_exclude_mode(self): + p1, p2, session = _patched_session() + existing = MagicMock() + existing.upgrade_mode = "exclude" + existing.exclude_plugins = ["p-existing"] + session.query.return_value.where.return_value.first.return_value = existing + + with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + strat_cls.UpgradeMode.EXCLUDE = "exclude" + strat_cls.UpgradeMode.PARTIAL = "partial" + strat_cls.UpgradeMode.ALL = "all" + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.exclude_plugin("t1", "p-new") + + assert result is True + assert existing.exclude_plugins == ["p-existing", "p-new"] + session.commit.assert_called_once() + + def test_removes_from_include_list_in_partial_mode(self): + p1, p2, session = _patched_session() + existing = MagicMock() + existing.upgrade_mode = "partial" + existing.include_plugins = ["p1", "p2"] + session.query.return_value.where.return_value.first.return_value = existing + + with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + strat_cls.UpgradeMode.EXCLUDE = "exclude" + strat_cls.UpgradeMode.PARTIAL = "partial" + strat_cls.UpgradeMode.ALL = "all" + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.exclude_plugin("t1", "p1") + + assert result is True + assert existing.include_plugins == ["p2"] + + def test_switches_to_exclude_mode_from_all(self): + p1, p2, session = _patched_session() + existing = MagicMock() + existing.upgrade_mode = "all" + session.query.return_value.where.return_value.first.return_value = existing + + with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + strat_cls.UpgradeMode.EXCLUDE = "exclude" + strat_cls.UpgradeMode.PARTIAL = "partial" + strat_cls.UpgradeMode.ALL = "all" + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + result = PluginAutoUpgradeService.exclude_plugin("t1", "p1") + + assert result is True + assert existing.upgrade_mode == "exclude" + assert existing.exclude_plugins == ["p1"] + + def test_no_duplicate_in_exclude_list(self): + p1, p2, session = _patched_session() + existing = MagicMock() + existing.upgrade_mode = "exclude" + existing.exclude_plugins = ["p1"] + session.query.return_value.where.return_value.first.return_value = existing + + with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + strat_cls.UpgradeMode.EXCLUDE = "exclude" + strat_cls.UpgradeMode.PARTIAL = "partial" + strat_cls.UpgradeMode.ALL = "all" + from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService + + PluginAutoUpgradeService.exclude_plugin("t1", "p1") + + assert existing.exclude_plugins == ["p1"] diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py new file mode 100644 index 0000000000..69091110db --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py @@ -0,0 +1,75 @@ +from unittest.mock import MagicMock, patch + +from models.account import TenantPluginPermission + +MODULE = "services.plugin.plugin_permission_service" + + +def _patched_session(): + """Patch Session(db.engine) to return a mock session as context manager.""" + session = MagicMock() + session_cls = MagicMock() + session_cls.return_value.__enter__ = MagicMock(return_value=session) + session_cls.return_value.__exit__ = MagicMock(return_value=False) + patcher = patch(f"{MODULE}.Session", session_cls) + db_patcher = patch(f"{MODULE}.db") + return patcher, db_patcher, session + + +class TestGetPermission: + def test_returns_permission_when_found(self): + p1, p2, session = _patched_session() + permission = MagicMock() + session.query.return_value.where.return_value.first.return_value = permission + + with p1, p2: + from services.plugin.plugin_permission_service import PluginPermissionService + + result = PluginPermissionService.get_permission("t1") + + assert result is permission + + def test_returns_none_when_not_found(self): + p1, p2, session = _patched_session() + session.query.return_value.where.return_value.first.return_value = None + + with p1, p2: + from services.plugin.plugin_permission_service import PluginPermissionService + + result = PluginPermissionService.get_permission("t1") + + assert result is None + + +class TestChangePermission: + def test_creates_new_permission_when_not_exists(self): + p1, p2, session = _patched_session() + session.query.return_value.where.return_value.first.return_value = None + + with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls: + perm_cls.return_value = MagicMock() + from services.plugin.plugin_permission_service import PluginPermissionService + + result = PluginPermissionService.change_permission( + "t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE + ) + + session.add.assert_called_once() + session.commit.assert_called_once() + + def test_updates_existing_permission(self): + p1, p2, session = _patched_session() + existing = MagicMock() + session.query.return_value.where.return_value.first.return_value = existing + + with p1, p2: + from services.plugin.plugin_permission_service import PluginPermissionService + + result = PluginPermissionService.change_permission( + "t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS + ) + + assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS + assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS + session.commit.assert_called_once() + session.add.assert_not_called() diff --git a/api/tests/unit_tests/services/retention/__init__.py b/api/tests/unit_tests/services/retention/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/retention/test_messages_clean_policy.py b/api/tests/unit_tests/services/retention/test_messages_clean_policy.py new file mode 100644 index 0000000000..79c079c683 --- /dev/null +++ b/api/tests/unit_tests/services/retention/test_messages_clean_policy.py @@ -0,0 +1,135 @@ +import datetime +from unittest.mock import MagicMock, patch + +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + SimpleMessage, + create_message_clean_policy, +) + +MODULE = "services.retention.conversation.messages_clean_policy" + + +def _msg(msg_id: str, app_id: str, days_ago: int = 0) -> SimpleMessage: + return SimpleMessage( + id=msg_id, + app_id=app_id, + created_at=datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days_ago), + ) + + +class TestBillingDisabledPolicy: + def test_returns_all_message_ids(self): + policy = BillingDisabledPolicy() + msgs = [_msg("m1", "app1"), _msg("m2", "app2"), _msg("m3", "app1")] + + result = policy.filter_message_ids(msgs, {"app1": "t1", "app2": "t2"}) + + assert set(result) == {"m1", "m2", "m3"} + + def test_empty_messages_returns_empty(self): + assert BillingDisabledPolicy().filter_message_ids([], {}) == [] + + +class TestBillingSandboxPolicy: + def _policy(self, plans, *, graceful_days=21, whitelist=None, now=1_000_000_000): + return BillingSandboxPolicy( + plan_provider=lambda _ids: plans, + graceful_period_days=graceful_days, + tenant_whitelist=whitelist, + current_timestamp=now, + ) + + def test_empty_messages_returns_empty(self): + policy = self._policy({}) + assert policy.filter_message_ids([], {"app1": "t1"}) == [] + + def test_empty_app_to_tenant_returns_empty(self): + policy = self._policy({}) + assert policy.filter_message_ids([_msg("m1", "app1")], {}) == [] + + def test_empty_plans_returns_empty(self): + policy = self._policy({}) + msgs = [_msg("m1", "app1")] + assert policy.filter_message_ids(msgs, {"app1": "t1"}) == [] + + def test_non_sandbox_tenant_skipped(self): + plans = {"t1": {"plan": "professional", "expiration_date": 0}} + policy = self._policy(plans) + msgs = [_msg("m1", "app1")] + + assert policy.filter_message_ids(msgs, {"app1": "t1"}) == [] + + def test_sandbox_no_previous_subscription_deletes(self): + plans = {"t1": {"plan": "sandbox", "expiration_date": -1}} + policy = self._policy(plans) + msgs = [_msg("m1", "app1")] + + assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"] + + def test_sandbox_expired_beyond_grace_period_deletes(self): + now = 1_000_000_000 + expired_long_ago = now - (22 * 24 * 60 * 60) # 22 days ago > 21 day grace + plans = {"t1": {"plan": "sandbox", "expiration_date": expired_long_ago}} + policy = self._policy(plans, now=now) + msgs = [_msg("m1", "app1")] + + assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"] + + def test_sandbox_within_grace_period_kept(self): + now = 1_000_000_000 + expired_recently = now - (10 * 24 * 60 * 60) # 10 days ago < 21 day grace + plans = {"t1": {"plan": "sandbox", "expiration_date": expired_recently}} + policy = self._policy(plans, now=now) + msgs = [_msg("m1", "app1")] + + assert policy.filter_message_ids(msgs, {"app1": "t1"}) == [] + + def test_whitelisted_tenant_skipped(self): + plans = {"t1": {"plan": "sandbox", "expiration_date": -1}} + policy = self._policy(plans, whitelist=["t1"]) + msgs = [_msg("m1", "app1")] + + assert policy.filter_message_ids(msgs, {"app1": "t1"}) == [] + + def test_message_without_tenant_mapping_skipped(self): + plans = {"t1": {"plan": "sandbox", "expiration_date": -1}} + policy = self._policy(plans) + msgs = [_msg("m1", "unmapped_app")] + + assert policy.filter_message_ids(msgs, {"app1": "t1"}) == [] + + def test_mixed_tenants_only_sandbox_deleted(self): + plans = { + "t_sandbox": {"plan": "sandbox", "expiration_date": -1}, + "t_pro": {"plan": "professional", "expiration_date": 0}, + } + policy = self._policy(plans) + msgs = [_msg("m1", "app_sandbox"), _msg("m2", "app_pro")] + app_map = {"app_sandbox": "t_sandbox", "app_pro": "t_pro"} + + result = policy.filter_message_ids(msgs, app_map) + + assert result == ["m1"] + + +class TestCreateMessageCleanPolicy: + def test_billing_disabled_returns_disabled_policy(self): + with patch(f"{MODULE}.dify_config") as cfg: + cfg.BILLING_ENABLED = False + policy = create_message_clean_policy() + + assert isinstance(policy, BillingDisabledPolicy) + + def test_billing_enabled_returns_sandbox_policy(self): + with ( + patch(f"{MODULE}.dify_config") as cfg, + patch(f"{MODULE}.BillingService") as bs, + ): + cfg.BILLING_ENABLED = True + bs.get_expired_subscription_cleanup_whitelist.return_value = ["wl1"] + bs.get_plan_bulk_with_cache = MagicMock() + policy = create_message_clean_policy(graceful_period_days=30) + + assert isinstance(policy, BillingSandboxPolicy) diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py new file mode 100644 index 0000000000..32c1a00d30 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tools_transform_service.py @@ -0,0 +1,598 @@ +from unittest.mock import MagicMock, Mock, patch + +from core.tools.__base.tool import Tool +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderType +from services.tools.tools_transform_service import ToolTransformService + +MODULE = "services.tools.tools_transform_service" + + +class TestToolTransformService: + """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method""" + + def test_convert_tool_with_parameter_override(self): + """Test that runtime parameters correctly override base parameters""" + # Create mock base parameters + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + base_param2 = Mock(spec=ToolParameter) + base_param2.name = "param2" + base_param2.form = ToolParameter.ToolParameterForm.FORM + base_param2.type = "string" + base_param2.label = "Base Param 2" + + # Create mock runtime parameters that override base parameters + runtime_param1 = Mock(spec=ToolParameter) + runtime_param1.name = "param1" + runtime_param1.form = ToolParameter.ToolParameterForm.FORM + runtime_param1.type = "string" + runtime_param1.label = "Runtime Param 1" # Different label to verify override + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1, base_param2] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param1] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.author == "test_author" + assert result.name == "test_tool" + assert result.parameters is not None + assert len(result.parameters) == 2 + + # Find the overridden parameter + overridden_param = next((p for p in result.parameters if p.name == "param1"), None) + assert overridden_param is not None + assert overridden_param.label == "Runtime Param 1" # Should be runtime version + + # Find the non-overridden parameter + original_param = next((p for p in result.parameters if p.name == "param2"), None) + assert original_param is not None + assert original_param.label == "Base Param 2" # Should be base version + + def test_convert_tool_with_additional_runtime_parameters(self): + """Test that additional runtime parameters are added to the final list""" + # Create mock base parameters + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + # Create mock runtime parameters - one that overrides and one that's new + runtime_param1 = Mock(spec=ToolParameter) + runtime_param1.name = "param1" + runtime_param1.form = ToolParameter.ToolParameterForm.FORM + runtime_param1.type = "string" + runtime_param1.label = "Runtime Param 1" + + runtime_param2 = Mock(spec=ToolParameter) + runtime_param2.name = "runtime_only" + runtime_param2.form = ToolParameter.ToolParameterForm.FORM + runtime_param2.type = "string" + runtime_param2.label = "Runtime Only Param" + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 2 + + # Check that both parameters are present + param_names = [p.name for p in result.parameters] + assert "param1" in param_names + assert "runtime_only" in param_names + + # Verify the overridden parameter has runtime version + overridden_param = next((p for p in result.parameters if p.name == "param1"), None) + assert overridden_param is not None + assert overridden_param.label == "Runtime Param 1" + + # Verify the new runtime parameter is included + new_param = next((p for p in result.parameters if p.name == "runtime_only"), None) + assert new_param is not None + assert new_param.label == "Runtime Only Param" + + def test_convert_tool_with_non_form_runtime_parameters(self): + """Test that non-FORM runtime parameters are not added as new parameters""" + # Create mock base parameters + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + # Create mock runtime parameters with different forms + runtime_param1 = Mock(spec=ToolParameter) + runtime_param1.name = "param1" + runtime_param1.form = ToolParameter.ToolParameterForm.FORM + runtime_param1.type = "string" + runtime_param1.label = "Runtime Param 1" + + runtime_param2 = Mock(spec=ToolParameter) + runtime_param2.name = "llm_param" + runtime_param2.form = ToolParameter.ToolParameterForm.LLM + runtime_param2.type = "string" + runtime_param2.label = "LLM Param" + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 1 # Only the FORM parameter should be present + + # Check that only the FORM parameter is present + param_names = [p.name for p in result.parameters] + assert "param1" in param_names + assert "llm_param" not in param_names + + def test_convert_tool_with_empty_parameters(self): + """Test conversion with empty base and runtime parameters""" + # Create mock tool with no parameters + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 0 + + def test_convert_tool_with_none_parameters(self): + """Test conversion when base parameters is None""" + # Create mock tool with None parameters + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = None + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 0 + + def test_convert_tool_parameter_order_preserved(self): + """Test that parameter order is preserved correctly""" + # Create mock base parameters in specific order + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + base_param2 = Mock(spec=ToolParameter) + base_param2.name = "param2" + base_param2.form = ToolParameter.ToolParameterForm.FORM + base_param2.type = "string" + base_param2.label = "Base Param 2" + + base_param3 = Mock(spec=ToolParameter) + base_param3.name = "param3" + base_param3.form = ToolParameter.ToolParameterForm.FORM + base_param3.type = "string" + base_param3.label = "Base Param 3" + + # Create runtime parameter that overrides middle parameter + runtime_param2 = Mock(spec=ToolParameter) + runtime_param2.name = "param2" + runtime_param2.form = ToolParameter.ToolParameterForm.FORM + runtime_param2.type = "string" + runtime_param2.label = "Runtime Param 2" + + # Create new runtime parameter + runtime_param4 = Mock(spec=ToolParameter) + runtime_param4.name = "param4" + runtime_param4.form = ToolParameter.ToolParameterForm.FORM + runtime_param4.type = "string" + runtime_param4.label = "Runtime Param 4" + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1, base_param2, base_param3] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 4 + + # Check that order is maintained: base parameters first, then new runtime parameters + param_names = [p.name for p in result.parameters] + assert param_names == ["param1", "param2", "param3", "param4"] + + # Verify that param2 was overridden with runtime version + param2 = result.parameters[1] + assert param2.name == "param2" + assert param2.label == "Runtime Param 2" + + +class TestWorkflowProviderToUserProvider: + """Test cases for ToolTransformService.workflow_provider_to_user_provider method""" + + def test_workflow_provider_to_user_provider_with_workflow_app_id(self): + """Test that workflow_provider_to_user_provider correctly sets workflow_app_id.""" + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + # Create mock workflow tool provider controller + workflow_app_id = "app_123" + provider_id = "provider_123" + mock_controller = Mock(spec=WorkflowToolProviderController) + mock_controller.provider_id = provider_id + mock_controller.entity = Mock() + mock_controller.entity.identity = Mock() + mock_controller.entity.identity.author = "test_author" + mock_controller.entity.identity.name = "test_workflow_tool" + mock_controller.entity.identity.description = I18nObject(en_US="Test description") + mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} + mock_controller.entity.identity.icon_dark = None + mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") + + # Call the method + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=mock_controller, + labels=["label1", "label2"], + workflow_app_id=workflow_app_id, + ) + + # Verify the result + assert isinstance(result, ToolProviderApiEntity) + assert result.id == provider_id + assert result.author == "test_author" + assert result.name == "test_workflow_tool" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == workflow_app_id + assert result.labels == ["label1", "label2"] + assert result.is_team_authorization is True + assert result.plugin_id is None + assert result.plugin_unique_identifier is None + assert result.tools == [] + + def test_workflow_provider_to_user_provider_without_workflow_app_id(self): + """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided.""" + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + # Create mock workflow tool provider controller + provider_id = "provider_123" + mock_controller = Mock(spec=WorkflowToolProviderController) + mock_controller.provider_id = provider_id + mock_controller.entity = Mock() + mock_controller.entity.identity = Mock() + mock_controller.entity.identity.author = "test_author" + mock_controller.entity.identity.name = "test_workflow_tool" + mock_controller.entity.identity.description = I18nObject(en_US="Test description") + mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} + mock_controller.entity.identity.icon_dark = None + mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") + + # Call the method without workflow_app_id + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=mock_controller, + labels=["label1"], + ) + + # Verify the result + assert isinstance(result, ToolProviderApiEntity) + assert result.id == provider_id + assert result.workflow_app_id is None + assert result.labels == ["label1"] + + def test_workflow_provider_to_user_provider_workflow_app_id_none(self): + """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly.""" + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + # Create mock workflow tool provider controller + provider_id = "provider_123" + mock_controller = Mock(spec=WorkflowToolProviderController) + mock_controller.provider_id = provider_id + mock_controller.entity = Mock() + mock_controller.entity.identity = Mock() + mock_controller.entity.identity.author = "test_author" + mock_controller.entity.identity.name = "test_workflow_tool" + mock_controller.entity.identity.description = I18nObject(en_US="Test description") + mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} + mock_controller.entity.identity.icon_dark = None + mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") + + # Call the method with explicit None values + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=mock_controller, + labels=None, + workflow_app_id=None, + ) + + # Verify the result + assert isinstance(result, ToolProviderApiEntity) + assert result.id == provider_id + assert result.workflow_app_id is None + assert result.labels == [] + + def test_workflow_provider_to_user_provider_preserves_other_fields(self): + """Test that workflow_provider_to_user_provider preserves all other entity fields.""" + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + # Create mock workflow tool provider controller with various fields + workflow_app_id = "app_456" + provider_id = "provider_456" + mock_controller = Mock(spec=WorkflowToolProviderController) + mock_controller.provider_id = provider_id + mock_controller.entity = Mock() + mock_controller.entity.identity = Mock() + mock_controller.entity.identity.author = "another_author" + mock_controller.entity.identity.name = "another_workflow_tool" + mock_controller.entity.identity.description = I18nObject( + en_US="Another description", zh_Hans="Another description" + ) + mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"} + mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"} + mock_controller.entity.identity.label = I18nObject( + en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool" + ) + + # Call the method + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=mock_controller, + labels=["automation", "workflow"], + workflow_app_id=workflow_app_id, + ) + + # Verify all fields are preserved correctly + assert isinstance(result, ToolProviderApiEntity) + assert result.id == provider_id + assert result.author == "another_author" + assert result.name == "another_workflow_tool" + assert result.description.en_US == "Another description" + assert result.description.zh_Hans == "Another description" + assert result.icon == {"type": "emoji", "content": "⚙️"} + assert result.icon_dark == {"type": "emoji", "content": "🔧"} + assert result.label.en_US == "Another Workflow Tool" + assert result.label.zh_Hans == "Another Workflow Tool" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == workflow_app_id + assert result.labels == ["automation", "workflow"] + assert result.masked_credentials == {} + assert result.is_team_authorization is True + assert result.allow_delete is True + assert result.plugin_id is None + assert result.plugin_unique_identifier is None + assert result.tools == [] + + +class TestGetToolProviderIconUrl: + def test_builtin_provider_returns_console_url(self): + with patch(f"{MODULE}.dify_config") as cfg: + cfg.CONSOLE_API_URL = "https://app.dify.ai" + url = ToolTransformService.get_tool_provider_icon_url("builtin", "google", "icon.png") + + assert "/builtin/google/icon" in url + assert url.startswith("https://app.dify.ai/console/api/workspaces/current/tool-provider") + + def test_builtin_provider_with_no_console_url(self): + with patch(f"{MODULE}.dify_config") as cfg: + cfg.CONSOLE_API_URL = None + url = ToolTransformService.get_tool_provider_icon_url("builtin", "slack", "icon.png") + + assert "/builtin/slack/icon" in url + + def test_api_provider_parses_json_icon(self): + icon_json = '{"background": "#fff", "content": "A"}' + result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon_json) + assert result == {"background": "#fff", "content": "A"} + + def test_api_provider_returns_dict_icon_directly(self): + icon = {"background": "#000", "content": "B"} + result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon) + assert result == icon + + def test_api_provider_returns_fallback_on_invalid_json(self): + result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", "not-json") + assert result == {"background": "#252525", "content": "\ud83d\ude01"} + + def test_workflow_provider_behaves_like_api(self): + icon = {"background": "#123", "content": "W"} + assert ToolTransformService.get_tool_provider_icon_url("workflow", "wf", icon) == icon + + def test_mcp_returns_icon_as_is(self): + assert ToolTransformService.get_tool_provider_icon_url("mcp", "srv", "icon-value") == "icon-value" + + def test_unknown_type_returns_empty(self): + assert ToolTransformService.get_tool_provider_icon_url("unknown", "x", "i") == "" + + +class TestRepackProvider: + def test_repacks_dict_provider_icon(self): + provider = {"type": "builtin", "name": "google", "icon": "old"} + with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/new-url") as mock_fn: + ToolTransformService.repack_provider("t1", provider) + + assert provider["icon"] == "/new-url" + mock_fn.assert_called_once_with(provider_type="builtin", provider_name="google", icon="old") + + def test_repacks_tool_provider_api_entity_without_plugin(self): + entity = MagicMock(spec=ToolProviderApiEntity) + entity.plugin_id = None + entity.type = ToolProviderType.BUILT_IN + entity.name = "slack" + entity.icon = "icon.svg" + entity.icon_dark = "dark.svg" + + with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/url"): + ToolTransformService.repack_provider("t1", entity) + + assert entity.icon == "/url" + assert entity.icon_dark == "/url" + + +class TestConvertMcpSchemaToParameter: + def test_simple_object_schema(self): + schema = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "count": {"type": "integer", "description": "Result count"}, + }, + "required": ["query"], + } + + params = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(params) == 2 + query_param = next(p for p in params if p.name == "query") + count_param = next(p for p in params if p.name == "count") + assert query_param.required is True + assert count_param.required is False + assert count_param.type.value == "number" + + def test_float_maps_to_number(self): + schema = {"type": "object", "properties": {"rate": {"type": "float"}}, "required": []} + assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "number" + + def test_array_type_attaches_input_schema(self): + prop = {"type": "array", "description": "Items", "items": {"type": "string"}} + schema = {"type": "object", "properties": {"items": prop}, "required": []} + param = ToolTransformService.convert_mcp_schema_to_parameter(schema)[0] + assert param.input_schema is not None + + def test_non_object_schema_returns_empty(self): + assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "string"}) == [] + + def test_missing_properties_returns_empty(self): + assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "object"}) == [] + + def test_list_type_uses_first_element(self): + schema = {"type": "object", "properties": {"f": {"type": ["string", "null"]}}, "required": []} + assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "string" + + def test_missing_description_defaults_empty(self): + schema = {"type": "object", "properties": {"f": {"type": "string"}}, "required": []} + assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].llm_description == "" + + +class TestApiProviderToController: + def test_api_key_header_auth(self): + db_provider = MagicMock() + db_provider.credentials = {"auth_type": "api_key_header"} + with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls: + ctrl_cls.from_db.return_value = MagicMock() + ToolTransformService.api_provider_to_controller(db_provider) + ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER) + + def test_api_key_query_auth(self): + db_provider = MagicMock() + db_provider.credentials = {"auth_type": "api_key_query"} + with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls: + ctrl_cls.from_db.return_value = MagicMock() + ToolTransformService.api_provider_to_controller(db_provider) + ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_QUERY) + + def test_legacy_api_key_maps_to_header(self): + db_provider = MagicMock() + db_provider.credentials = {"auth_type": "api_key"} + with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls: + ctrl_cls.from_db.return_value = MagicMock() + ToolTransformService.api_provider_to_controller(db_provider) + ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER) + + def test_unknown_auth_defaults_to_none(self): + db_provider = MagicMock() + db_provider.credentials = {"auth_type": "something_else"} + with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls: + ctrl_cls.from_db.return_value = MagicMock() + ToolTransformService.api_provider_to_controller(db_provider) + ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.NONE) From adc6c6c13b182216f4c53a077ffe996187610b1f Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Tue, 31 Mar 2026 11:46:02 +0800 Subject: [PATCH 2/7] chore: try to avoid supply chain security (#34317) --- pnpm-lock.yaml | 6 ++---- pnpm-workspace.yaml | 19 +++++++++++-------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 01a96c5585..b6c234d8ad 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -249,9 +249,6 @@ catalogs: autoprefixer: specifier: 10.4.27 version: 10.4.27 - axios: - specifier: ^1.14.0 - version: 1.14.0 class-variance-authority: specifier: 0.7.1 version: 0.7.1 @@ -573,6 +570,7 @@ overrides: array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44 array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44 assert: npm:@nolyfill/assert@^1.0.26 + axios: 1.14.0 brace-expansion@<2.0.2: 2.0.2 canvas: ^3.2.2 devalue@<5.3.2: 5.3.2 @@ -652,7 +650,7 @@ importers: sdks/nodejs-client: dependencies: axios: - specifier: 'catalog:' + specifier: 1.14.0 version: 1.14.0 devDependencies: '@eslint/js': diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index dece6f3f4f..ae53a57832 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -1,3 +1,12 @@ +trustPolicy: no-downgrade +minimumReleaseAge: 1440 +blockExoticSubdeps: true +strictDepBuilds: true +allowBuilds: + '@parcel/watcher': false + canvas: false + esbuild: false + sharp: false packages: - web - e2e @@ -13,6 +22,7 @@ overrides: array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44 array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44 assert: npm:@nolyfill/assert@^1.0.26 + axios: 1.14.0 brace-expansion@<2.0.2: 2.0.2 canvas: ^3.2.2 devalue@<5.3.2: 5.3.2 @@ -59,13 +69,6 @@ overrides: which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44 yaml@>=2.0.0 <2.8.3: 2.8.3 yauzl@<3.2.1: 3.2.1 -ignoredBuiltDependencies: - - canvas - - core-js-pure -onlyBuiltDependencies: - - "@parcel/watcher" - - esbuild - - sharp catalog: "@amplitude/analytics-browser": 2.38.0 "@amplitude/plugin-session-replay-browser": 1.27.5 @@ -149,7 +152,7 @@ catalog: agentation: 3.0.2 ahooks: 3.9.7 autoprefixer: 10.4.27 - axios: ^1.14.0 + axios: 1.14.0 class-variance-authority: 0.7.1 clsx: 2.1.1 cmdk: 1.1.1 From 88863609e979d5eb080f6e61be6b1a406773cdcd Mon Sep 17 00:00:00 2001 From: YBoy Date: Tue, 31 Mar 2026 07:56:53 +0300 Subject: [PATCH 3/7] test: migrate rag pipeline controller tests to testcontainers (#34303) --- .../rag_pipeline/test_rag_pipeline.py | 91 +++++++++---------- 1 file changed, 44 insertions(+), 47 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py (77%) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py similarity index 77% rename from api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py rename to api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index ebbb34e069..d5ae95dfb7 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -1,6 +1,12 @@ +"""Testcontainers integration tests for rag_pipeline controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from controllers.console import console_ns from controllers.console.datasets.rag_pipeline.rag_pipeline import ( @@ -9,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import ( PipelineTemplateListApi, PublishCustomizedPipelineTemplateApi, ) +from models.dataset import PipelineCustomizedTemplate def unwrap(func): @@ -18,6 +25,10 @@ def unwrap(func): class TestPipelineTemplateListApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_success(self, app): api = PipelineTemplateListApi() method = unwrap(api.get) @@ -38,6 +49,10 @@ class TestPipelineTemplateListApi: class TestPipelineTemplateDetailApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_success(self, app): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -99,6 +114,10 @@ class TestPipelineTemplateDetailApi: class TestCustomizedPipelineTemplateApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_patch_success(self, app): api = CustomizedPipelineTemplateApi() method = unwrap(api.patch) @@ -136,35 +155,29 @@ class TestCustomizedPipelineTemplateApi: delete_mock.assert_called_once_with("tpl-1") assert response == 200 - def test_post_success(self, app): + def test_post_success(self, app, db_session_with_containers: Session): api = CustomizedPipelineTemplateApi() method = unwrap(api.post) - template = MagicMock() - template.yaml_content = "yaml-data" + tenant_id = str(uuid4()) + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name="Test Template", + description="Test", + chunk_structure="hierarchical", + icon={"icon": "📘"}, + position=0, + yaml_content="yaml-data", + install_count=0, + language="en-US", + created_by=str(uuid4()), + ) + db_session_with_containers.add(template) + db_session_with_containers.commit() + db_session_with_containers.expire_all() - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = template - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - with ( - app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", - return_value=session_ctx, - ), - ): - response, status = method(api, "tpl-1") + with app.test_request_context("/"): + response, status = method(api, template.id) assert status == 200 assert response == {"data": "yaml-data"} @@ -173,32 +186,16 @@ class TestCustomizedPipelineTemplateApi: api = CustomizedPipelineTemplateApi() method = unwrap(api.post) - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - with ( - app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", - return_value=session_ctx, - ), - ): + with app.test_request_context("/"): with pytest.raises(ValueError): - method(api, "tpl-1") + method(api, str(uuid4())) class TestPublishCustomizedPipelineTemplateApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_post_success(self, app): api = PublishCustomizedPipelineTemplateApi() method = unwrap(api.post) From 9b7b432e0822c4fffd09d3166132dddd4ae969e7 Mon Sep 17 00:00:00 2001 From: YBoy Date: Tue, 31 Mar 2026 07:57:53 +0300 Subject: [PATCH 4/7] test: migrate rag pipeline import controller tests to testcontainers (#34305) --- .../rag_pipeline/test_rag_pipeline_import.py | 130 +++--------------- 1 file changed, 22 insertions(+), 108 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py (66%) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py similarity index 66% rename from api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py rename to api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index a72ad45110..cb67892878 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -1,5 +1,11 @@ +"""Testcontainers integration tests for rag_pipeline_import controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch +import pytest + from controllers.console import console_ns from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( RagPipelineExportApi, @@ -18,6 +24,10 @@ def unwrap(func): class TestRagPipelineImportApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def _payload(self, mode="create"): return { "mode": mode, @@ -30,7 +40,6 @@ class TestRagPipelineImportApi: method = unwrap(api.post) payload = self._payload() - user = MagicMock() result = MagicMock() result.status = "completed" @@ -39,13 +48,6 @@ class TestRagPipelineImportApi: service = MagicMock() service.import_rag_pipeline.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -53,14 +55,6 @@ class TestRagPipelineImportApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -76,7 +70,6 @@ class TestRagPipelineImportApi: method = unwrap(api.post) payload = self._payload() - user = MagicMock() result = MagicMock() result.status = ImportStatus.FAILED @@ -85,13 +78,6 @@ class TestRagPipelineImportApi: service = MagicMock() service.import_rag_pipeline.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -99,14 +85,6 @@ class TestRagPipelineImportApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -122,7 +100,6 @@ class TestRagPipelineImportApi: method = unwrap(api.post) payload = self._payload() - user = MagicMock() result = MagicMock() result.status = ImportStatus.PENDING @@ -131,13 +108,6 @@ class TestRagPipelineImportApi: service = MagicMock() service.import_rag_pipeline.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -145,14 +115,6 @@ class TestRagPipelineImportApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -165,6 +127,10 @@ class TestRagPipelineImportApi: class TestRagPipelineImportConfirmApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_confirm_success(self, app): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi: service = MagicMock() service.confirm_import.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi: service = MagicMock() service.confirm_import.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi: class TestRagPipelineImportCheckDependenciesApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_success(self, app): api = RagPipelineImportCheckDependenciesApi() method = unwrap(api.get) @@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi: service = MagicMock() service.check_dependencies.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi: class TestRagPipelineExportApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_with_include_secret(self, app): api = RagPipelineExportApi() method = unwrap(api.get) @@ -301,23 +230,8 @@ class TestRagPipelineExportApi: service = MagicMock() service.export_rag_pipeline_dsl.return_value = {"yaml": "data"} - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/?include_secret=true"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, From cc68f0e640f0932e4e037ab307e6dc7c4ae8b53f Mon Sep 17 00:00:00 2001 From: YBoy Date: Tue, 31 Mar 2026 07:58:14 +0300 Subject: [PATCH 5/7] test: migrate rag pipeline workflow controller tests to testcontainers (#34306) --- .../test_rag_pipeline_workflow.py | 145 +++++++++--------- 1 file changed, 70 insertions(+), 75 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py (91%) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py similarity index 91% rename from api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py rename to api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index a3c0592d76..c1f3122c2b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -1,7 +1,13 @@ +"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints.""" + +from __future__ import annotations + from datetime import datetime from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound import services @@ -38,6 +44,10 @@ def unwrap(func): class TestDraftWorkflowApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_draft_success(self, app): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -200,6 +210,10 @@ class TestDraftWorkflowApi: class TestDraftRunNodes: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_iteration_node_success(self, app): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -275,6 +289,10 @@ class TestDraftRunNodes: class TestPipelineRunApis: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_draft_run_success(self, app): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -337,6 +355,10 @@ class TestPipelineRunApis: class TestDraftNodeRun: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_execution_not_found(self, app): api = RagPipelineDraftNodeRunApi() method = unwrap(api.post) @@ -364,45 +386,43 @@ class TestDraftNodeRun: class TestPublishedPipelineApis: - def test_publish_success(self, app): + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + + def test_publish_success(self, app, db_session_with_containers: Session): + from models.dataset import Pipeline + api = PublishedRagPipelineApi() method = unwrap(api.post) - pipeline = MagicMock() + tenant_id = str(uuid4()) + pipeline = Pipeline( + tenant_id=tenant_id, + name="test-pipeline", + description="test", + created_by=str(uuid4()), + ) + db_session_with_containers.add(pipeline) + db_session_with_containers.commit() + db_session_with_containers.expire_all() + user = MagicMock(id="u1") workflow = MagicMock( - id="w1", + id=str(uuid4()), created_at=naive_utc_now(), ) - session = MagicMock() - session.merge.return_value = pipeline - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - service = MagicMock() service.publish_workflow.return_value = workflow - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", return_value=(user, "t"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, @@ -415,6 +435,10 @@ class TestPublishedPipelineApis: class TestMiscApis: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_task_stop(self, app): api = RagPipelineTaskStopApi() method = unwrap(api.post) @@ -471,6 +495,10 @@ class TestMiscApis: class TestPublishedRagPipelineRunApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_published_run_success(self, app): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi: class TestDefaultBlockConfigApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_block_config_success(self, app): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi: class TestPublishedAllRagPipelineApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_published_workflows_success(self, app): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi: service = MagicMock() service.get_all_published_workflow.return_value = ([{"id": "w1"}], False) - session = MagicMock() - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", return_value=(user, "t"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, @@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi: class TestRagPipelineByIdApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_patch_success(self, app): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -640,14 +664,6 @@ class TestRagPipelineByIdApi: service = MagicMock() service.update_workflow.return_value = workflow - session = MagicMock() - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - payload = {"marked_name": "test"} with ( @@ -657,14 +673,6 @@ class TestRagPipelineByIdApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", return_value=(user, "t"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, @@ -700,24 +708,8 @@ class TestRagPipelineByIdApi: workflow_service = MagicMock() - session = MagicMock() - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/", method="DELETE"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService", return_value=workflow_service, @@ -725,12 +717,7 @@ class TestRagPipelineByIdApi: ): result = method(api, pipeline, "old-workflow") - workflow_service.delete_workflow.assert_called_once_with( - session=session, - workflow_id="old-workflow", - tenant_id="t1", - ) - session.commit.assert_called_once() + workflow_service.delete_workflow.assert_called_once() assert result == (None, 204) def test_delete_active_workflow_rejected(self, app): @@ -745,6 +732,10 @@ class TestRagPipelineByIdApi: class TestRagPipelineWorkflowLastRunApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_last_run_success(self, app): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineDatasourceVariableApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_set_datasource_variables_success(self, app): api = RagPipelineDatasourceVariableApi() method = unwrap(api.post) From 303f5484085b16a088d2a8d734cfa5d20749ddec Mon Sep 17 00:00:00 2001 From: YBoy Date: Tue, 31 Mar 2026 07:59:13 +0300 Subject: [PATCH 6/7] test: migrate rag pipeline datasets controller tests to testcontainers (#34304) --- .../test_rag_pipeline_datasets.py | 42 ++++++------------- 1 file changed, 12 insertions(+), 30 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py (83%) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py similarity index 83% rename from api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py rename to api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py index fd38fcbb5e..64e3de2ca3 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -1,3 +1,7 @@ +"""Testcontainers integration tests for rag_pipeline_datasets controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest @@ -19,6 +23,10 @@ def unwrap(func): class TestCreateRagPipelineDatasetApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def _valid_payload(self): return {"yaml_content": "name: test"} @@ -33,13 +41,6 @@ class TestCreateRagPipelineDatasetApi: mock_service = MagicMock() mock_service.create_rag_pipeline_dataset.return_value = import_info - mock_session_ctx = MagicMock() - mock_session_ctx.__enter__.return_value = MagicMock() - mock_session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -47,14 +48,6 @@ class TestCreateRagPipelineDatasetApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", return_value=(user, "tenant-1"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", - return_value=mock_session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", return_value=mock_service, @@ -93,13 +86,6 @@ class TestCreateRagPipelineDatasetApi: mock_service = MagicMock() mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError() - mock_session_ctx = MagicMock() - mock_session_ctx.__enter__.return_value = MagicMock() - mock_session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -107,14 +93,6 @@ class TestCreateRagPipelineDatasetApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", return_value=(user, "tenant-1"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", - return_value=mock_session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", return_value=mock_service, @@ -143,6 +121,10 @@ class TestCreateRagPipelineDatasetApi: class TestCreateEmptyRagPipelineDatasetApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_post_success(self, app): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) From 1063e021f237922541621a9741595788a389345d Mon Sep 17 00:00:00 2001 From: YBoy Date: Tue, 31 Mar 2026 08:00:22 +0300 Subject: [PATCH 7/7] test: migrate explore conversation controller tests to testcontainers (#34312) --- .../console/explore/test_conversation.py | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/explore/test_conversation.py (82%) diff --git a/api/tests/unit_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py similarity index 82% rename from api/tests/unit_tests/controllers/console/explore/test_conversation.py rename to api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 65cc209725..83492048ef 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -1,7 +1,10 @@ +"""Testcontainers integration tests for controllers.console.explore.conversation endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.conversation as conversation_module @@ -48,24 +51,12 @@ def user(): return user -@pytest.fixture(autouse=True) -def mock_db_and_session(): - with ( - patch.object( - conversation_module, - "db", - MagicMock(session=MagicMock(), engine=MagicMock()), - ), - patch( - "controllers.console.explore.conversation.Session", - MagicMock(), - ), - ): - yield - - class TestConversationListApi: - def test_get_success(self, app: Flask, chat_app, user): + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + + def test_get_success(self, app, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -90,7 +81,7 @@ class TestConversationListApi: assert result["has_more"] is False assert len(result["data"]) == 2 - def test_last_conversation_not_exists(self, app: Flask, chat_app, user): + def test_last_conversation_not_exists(self, app, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -106,7 +97,7 @@ class TestConversationListApi: with pytest.raises(NotFound): method(chat_app) - def test_wrong_app_mode(self, app: Flask, non_chat_app): + def test_wrong_app_mode(self, app, non_chat_app): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -116,7 +107,11 @@ class TestConversationListApi: class TestConversationApi: - def test_delete_success(self, app: Flask, chat_app, user): + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + + def test_delete_success(self, app, chat_app, user): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -134,7 +129,7 @@ class TestConversationApi: assert status == 204 assert body["result"] == "success" - def test_delete_not_found(self, app: Flask, chat_app, user): + def test_delete_not_found(self, app, chat_app, user): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -150,7 +145,7 @@ class TestConversationApi: with pytest.raises(NotFound): method(chat_app, "cid") - def test_delete_wrong_app_mode(self, app: Flask, non_chat_app): + def test_delete_wrong_app_mode(self, app, non_chat_app): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -160,7 +155,11 @@ class TestConversationApi: class TestConversationRenameApi: - def test_rename_success(self, app: Flask, chat_app, user): + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + + def test_rename_success(self, app, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -179,7 +178,7 @@ class TestConversationRenameApi: assert result["id"] == "cid" - def test_rename_not_found(self, app: Flask, chat_app, user): + def test_rename_not_found(self, app, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -197,7 +196,11 @@ class TestConversationRenameApi: class TestConversationPinApi: - def test_pin_success(self, app: Flask, chat_app, user): + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + + def test_pin_success(self, app, chat_app, user): api = conversation_module.ConversationPinApi() method = unwrap(api.patch) @@ -215,7 +218,11 @@ class TestConversationPinApi: class TestConversationUnPinApi: - def test_unpin_success(self, app: Flask, chat_app, user): + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + + def test_unpin_success(self, app, chat_app, user): api = conversation_module.ConversationUnPinApi() method = unwrap(api.patch)