Merge branch 'main' into 3-31-vite-task-cache

This commit is contained in:
Stephen Zhou
2026-03-31 14:53:38 +08:00
committed by GitHub
19 changed files with 2112 additions and 299 deletions

View File

@@ -0,0 +1,182 @@
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import Tenant
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_permission_service import PluginPermissionService
@pytest.fixture
def tenant(flask_req_ctx):
    """Create a throwaway Tenant row for plugin-service tests; clean up afterwards.

    Yields the tenant id (a plain string) rather than the ORM object so the
    tests never touch an expired instance after the creating session closes.
    """
    with session_factory.create_session() as session:
        t = Tenant(name="plugin_it_tenant")
        session.add(t)
        session.commit()
        # Capture the id while the session is still open.
        tenant_id = t.id
    yield tenant_id
    # Teardown: delete dependent plugin rows first, then the tenant itself.
    with session_factory.create_session() as session:
        session.execute(delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id))
        session.execute(
            delete(TenantPluginAutoUpgradeStrategy).where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
        )
        session.execute(delete(Tenant).where(Tenant.id == tenant_id))
        session.commit()
class TestPluginPermissionLifecycle:
    """Integration tests for the PluginPermissionService create/read/update flow."""

    def test_get_returns_none_for_new_tenant(self, tenant):
        # A freshly created tenant has no permission row yet.
        assert PluginPermissionService.get_permission(tenant) is None

    def test_change_creates_row(self, tenant):
        # First change_permission call must insert a row and report success.
        result = PluginPermissionService.change_permission(
            tenant,
            TenantPluginPermission.InstallPermission.ADMINS,
            TenantPluginPermission.DebugPermission.EVERYONE,
        )
        assert result is True
        perm = PluginPermissionService.get_permission(tenant)
        assert perm is not None
        assert perm.install_permission == TenantPluginPermission.InstallPermission.ADMINS
        assert perm.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE

    def test_change_updates_existing_row(self, tenant):
        # Seed an initial permission row ...
        PluginPermissionService.change_permission(
            tenant,
            TenantPluginPermission.InstallPermission.ADMINS,
            TenantPluginPermission.DebugPermission.NOBODY,
        )
        # ... then change it again; the second call must update in place.
        PluginPermissionService.change_permission(
            tenant,
            TenantPluginPermission.InstallPermission.EVERYONE,
            TenantPluginPermission.DebugPermission.ADMINS,
        )
        perm = PluginPermissionService.get_permission(tenant)
        assert perm is not None
        assert perm.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
        assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
        # Exactly one row per tenant: the update must not insert a duplicate.
        with session_factory.create_session() as session:
            count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count()
            assert count == 1
class TestPluginAutoUpgradeLifecycle:
    """Integration tests for PluginAutoUpgradeService strategy CRUD and the
    mode transitions performed by exclude_plugin()."""

    def test_get_returns_none_for_new_tenant(self, tenant):
        # No strategy row exists until one is explicitly created.
        assert PluginAutoUpgradeService.get_strategy(tenant) is None

    def test_change_creates_row(self, tenant):
        result = PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=3,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        assert result is True
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        assert strategy.upgrade_time_of_day == 3

    def test_change_updates_existing_row(self, tenant):
        # Seed an initial strategy, then overwrite every field of it.
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=12,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            exclude_plugins=[],
            include_plugins=["plugin-a"],
        )
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        assert strategy.upgrade_time_of_day == 12
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
        assert strategy.include_plugins == ["plugin-a"]

    def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant):
        # exclude_plugin on a tenant with no strategy must bootstrap one in EXCLUDE mode.
        PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
        assert "my-plugin" in strategy.exclude_plugins

    def test_exclude_plugin_appends_in_exclude_mode(self, tenant):
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["existing"],
            include_plugins=[],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        # The new plugin is appended; the existing entry is preserved.
        assert "existing" in strategy.exclude_plugins
        assert "new-plugin" in strategy.exclude_plugins

    def test_exclude_plugin_dedup_in_exclude_mode(self, tenant):
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["same-plugin"],
            include_plugins=[],
        )
        # Excluding an already-excluded plugin must not create a duplicate entry.
        PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.exclude_plugins.count("same-plugin") == 1

    def test_exclude_from_partial_mode_removes_from_include(self, tenant):
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            exclude_plugins=[],
            include_plugins=["p1", "p2"],
        )
        # In PARTIAL mode, excluding a plugin removes it from the include list.
        PluginAutoUpgradeService.exclude_plugin(tenant, "p1")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert "p1" not in strategy.include_plugins
        assert "p2" in strategy.include_plugins

    def test_exclude_from_all_mode_switches_to_exclude(self, tenant):
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        # Excluding while in ALL mode flips the strategy into EXCLUDE mode.
        PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
        assert "excluded-plugin" in strategy.exclude_plugins

View File

@@ -0,0 +1,348 @@
import datetime
import math
import uuid
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import Tenant
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import (
App,
Conversation,
Message,
MessageAnnotation,
MessageFeedback,
)
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
from services.retention.conversation.messages_clean_service import MessagesCleanService
# Fixed reference "now" so the relative ages below are deterministic across runs.
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
# Three seed ages: two inside a typical retention window, one recent.
_OLD = _NOW - datetime.timedelta(days=60)
_VERY_OLD = _NOW - datetime.timedelta(days=90)
_RECENT = _NOW - datetime.timedelta(days=5)
# Window that covers all three seed timestamps with an hour of slack each side.
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)
_DEFAULT_BATCH_SIZE = 100
# Pagination scenario: 25 messages cleaned in batches of 8 → at least 4 batches.
_PAGINATION_MESSAGE_COUNT = 25
_PAGINATION_BATCH_SIZE = 8
@pytest.fixture
def tenant_and_app(flask_req_ctx):
    """Creates a Tenant, App and Conversation for the test and cleans up after."""
    with session_factory.create_session() as session:
        tenant = Tenant(name="retention_it_tenant")
        session.add(tenant)
        session.flush()  # assigns tenant.id for the FK below
        app = App(
            tenant_id=tenant.id,
            name="Retention IT App",
            mode="chat",
            enable_site=True,
            enable_api=True,
        )
        session.add(app)
        session.flush()  # assigns app.id for the conversation FK
        conv = Conversation(
            app_id=app.id,
            mode="chat",
            name="test_conv",
            status="normal",
            from_source="console",
            _inputs={},
        )
        session.add(conv)
        session.commit()
        # Capture plain id strings before the session closes and objects expire.
        tenant_id = tenant.id
        app_id = app.id
        conv_id = conv.id
    yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id}
    # Teardown in reverse dependency order: conversation → app → tenant.
    with session_factory.create_session() as session:
        session.execute(delete(Conversation).where(Conversation.id == conv_id))
        session.execute(delete(App).where(App.id == app_id))
        session.execute(delete(Tenant).where(Tenant.id == tenant_id))
        session.commit()
def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message:
    """Build an unsaved Message row carrying the minimal set of valid chat fields."""
    attrs = dict(
        app_id=app_id,
        conversation_id=conversation_id,
        query="test",
        message=[{"text": "hello"}],
        answer="world",
        message_tokens=1,
        message_unit_price=0,
        answer_tokens=1,
        answer_unit_price=0,
        from_source="console",
        currency="USD",
        _inputs={},
        created_at=created_at,
    )
    return Message(**attrs)
class TestMessagesCleanServiceIntegration:
    """End-to-end tests for MessagesCleanService time-window deletion:
    dry-run, filtering, batch pagination, cascades, and factory validation."""

    @pytest.fixture
    def seed_messages(self, tenant_and_app):
        """Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.

        Yields a semantic mapping keyed by age label.
        """
        data = tenant_and_app
        app_id = data["app_id"]
        conv_id = data["conversation_id"]
        # Ordered tuple of (label, timestamp) for deterministic seeding
        timestamps = [
            ("very_old", _VERY_OLD),
            ("old", _OLD),
            ("recent", _RECENT),
        ]
        msg_ids: dict[str, str] = {}
        with session_factory.create_session() as session:
            for label, ts in timestamps:
                msg = _make_message(app_id, conv_id, ts)
                session.add(msg)
                session.flush()  # populate msg.id before commit
                msg_ids[label] = msg.id
            session.commit()
        yield {"msg_ids": msg_ids, **data}
        # Teardown: bulk-delete whatever rows the test left behind.
        with session_factory.create_session() as session:
            session.execute(
                delete(Message)
                .where(Message.id.in_(list(msg_ids.values())))
                .execution_options(synchronize_session=False)
            )
            session.commit()

    @pytest.fixture
    def paginated_seed_messages(self, tenant_and_app):
        """Seeds multiple messages separated by 1-second increments starting at _OLD."""
        data = tenant_and_app
        app_id = data["app_id"]
        conv_id = data["conversation_id"]
        msg_ids: list[str] = []
        with session_factory.create_session() as session:
            for i in range(_PAGINATION_MESSAGE_COUNT):
                ts = _OLD + datetime.timedelta(seconds=i)
                msg = _make_message(app_id, conv_id, ts)
                session.add(msg)
                session.flush()
                msg_ids.append(msg.id)
            session.commit()
        yield {"msg_ids": msg_ids, **data}
        with session_factory.create_session() as session:
            session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False))
            session.commit()

    @pytest.fixture
    def cascade_test_data(self, tenant_and_app):
        """Seeds one Message with an associated Feedback and Annotation."""
        data = tenant_and_app
        app_id = data["app_id"]
        conv_id = data["conversation_id"]
        with session_factory.create_session() as session:
            msg = _make_message(app_id, conv_id, _OLD)
            session.add(msg)
            session.flush()  # msg.id needed by the child rows below
            feedback = MessageFeedback(
                app_id=app_id,
                conversation_id=conv_id,
                message_id=msg.id,
                rating=FeedbackRating.LIKE,
                from_source=FeedbackFromSource.USER,
            )
            annotation = MessageAnnotation(
                app_id=app_id,
                conversation_id=conv_id,
                message_id=msg.id,
                question="q",
                content="a",
                account_id=str(uuid.uuid4()),
            )
            session.add_all([feedback, annotation])
            session.commit()
            # Capture plain ids before the session closes.
            msg_id = msg.id
            fb_id = feedback.id
            ann_id = annotation.id
        yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}
        # Teardown: delete child rows before the parent message.
        with session_factory.create_session() as session:
            session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
            session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
            session.execute(delete(Message).where(Message.id == msg_id))
            session.commit()

    def test_dry_run_does_not_delete(self, seed_messages):
        """Dry-run must count eligible rows without deleting any of them."""
        data = seed_messages
        msg_ids = data["msg_ids"]
        all_ids = list(msg_ids.values())
        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=_WINDOW_START,
            end_before=_WINDOW_END,
            batch_size=_DEFAULT_BATCH_SIZE,
            dry_run=True,
        )
        stats = svc.run()
        assert stats["filtered_messages"] == len(all_ids)
        assert stats["total_deleted"] == 0
        # All seeded rows must still be present.
        with session_factory.create_session() as session:
            remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
            assert remaining == len(all_ids)

    def test_billing_disabled_deletes_all_in_range(self, seed_messages):
        """All 3 seeded messages fall within the window and must be deleted."""
        data = seed_messages
        msg_ids = data["msg_ids"]
        all_ids = list(msg_ids.values())
        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=_WINDOW_START,
            end_before=_WINDOW_END,
            batch_size=_DEFAULT_BATCH_SIZE,
            dry_run=False,
        )
        stats = svc.run()
        assert stats["total_deleted"] == len(all_ids)
        with session_factory.create_session() as session:
            remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
            assert remaining == 0

    def test_start_from_filters_correctly(self, seed_messages):
        """Only the message at _OLD falls within the narrow ±1 h window."""
        data = seed_messages
        msg_ids = data["msg_ids"]
        start = _OLD - datetime.timedelta(hours=1)
        end = _OLD + datetime.timedelta(hours=1)
        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=start,
            end_before=end,
            batch_size=_DEFAULT_BATCH_SIZE,
        )
        stats = svc.run()
        assert stats["total_deleted"] == 1
        # The _OLD message is gone; the out-of-window ones survive.
        with session_factory.create_session() as session:
            all_ids = list(msg_ids.values())
            remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
            assert msg_ids["old"] not in remaining_ids
            assert msg_ids["very_old"] in remaining_ids
            assert msg_ids["recent"] in remaining_ids

    def test_cursor_pagination_across_batches(self, paginated_seed_messages):
        """Messages must be deleted across multiple batches."""
        data = paginated_seed_messages
        msg_ids = data["msg_ids"]
        # _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s.
        pagination_window_start = _OLD - datetime.timedelta(seconds=1)
        pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)
        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=pagination_window_start,
            end_before=pagination_window_end,
            batch_size=_PAGINATION_BATCH_SIZE,
        )
        stats = svc.run()
        assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
        # 25 rows at batch size 8 requires at least ceil(25/8) = 4 batches.
        expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
        assert stats["batches"] >= expected_batches
        with session_factory.create_session() as session:
            remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
            assert remaining == 0

    def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
        """A window entirely in the future must yield zero matches."""
        far_future = _NOW + datetime.timedelta(days=365)
        even_further = far_future + datetime.timedelta(days=1)
        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=far_future,
            end_before=even_further,
            batch_size=_DEFAULT_BATCH_SIZE,
        )
        stats = svc.run()
        assert stats["total_messages"] == 0
        assert stats["total_deleted"] == 0

    def test_relation_cascade_deletes(self, cascade_test_data):
        """Deleting a Message must cascade to its Feedback and Annotation rows."""
        data = cascade_test_data
        msg_id = data["msg_id"]
        fb_id = data["fb_id"]
        ann_id = data["ann_id"]
        svc = MessagesCleanService.from_time_range(
            policy=BillingDisabledPolicy(),
            start_from=_OLD - datetime.timedelta(hours=1),
            end_before=_OLD + datetime.timedelta(hours=1),
            batch_size=_DEFAULT_BATCH_SIZE,
        )
        stats = svc.run()
        assert stats["total_deleted"] == 1
        with session_factory.create_session() as session:
            assert session.query(Message).where(Message.id == msg_id).count() == 0
            assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
            assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0

    def test_factory_from_time_range_validation(self):
        # start_from must precede end_before.
        with pytest.raises(ValueError, match="start_from"):
            MessagesCleanService.from_time_range(
                policy=BillingDisabledPolicy(),
                start_from=_NOW,
                end_before=_OLD,
            )

    def test_factory_from_days_validation(self):
        # Negative retention periods are rejected.
        with pytest.raises(ValueError, match="days"):
            MessagesCleanService.from_days(
                policy=BillingDisabledPolicy(),
                days=-1,
            )

    def test_factory_batch_size_validation(self):
        # A zero batch size would never make progress; it must be rejected.
        with pytest.raises(ValueError, match="batch_size"):
            MessagesCleanService.from_time_range(
                policy=BillingDisabledPolicy(),
                start_from=_OLD,
                end_before=_NOW,
                batch_size=0,
            )

View File

@@ -0,0 +1,177 @@
import datetime
import io
import json
import uuid
import zipfile
from unittest.mock import MagicMock, patch
import pytest
from services.retention.workflow_run.archive_paid_plan_workflow_run import (
ArchiveSummary,
WorkflowRunArchiver,
)
from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION
class TestWorkflowRunArchiverInit:
    """Constructor validation and default wiring for WorkflowRunArchiver."""

    def test_start_from_without_end_before_raises(self):
        with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
            WorkflowRunArchiver(start_from=datetime.datetime(2025, 1, 1))

    def test_end_before_without_start_from_raises(self):
        with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
            WorkflowRunArchiver(end_before=datetime.datetime(2025, 1, 1))

    def test_start_equals_end_raises(self):
        # A zero-width window is invalid: the bounds must differ.
        same_moment = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
            WorkflowRunArchiver(start_from=same_moment, end_before=same_moment)

    def test_start_after_end_raises(self):
        later = datetime.datetime(2025, 6, 1)
        earlier = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
            WorkflowRunArchiver(
                start_from=later,
                end_before=earlier,
            )

    def test_workers_zero_raises(self):
        with pytest.raises(ValueError, match="workers must be at least 1"):
            WorkflowRunArchiver(workers=0)

    def test_valid_init_defaults(self):
        # Only days/batch_size given: every optional flag keeps its default.
        archiver = WorkflowRunArchiver(days=30, batch_size=50)
        assert archiver.days == 30
        assert archiver.batch_size == 50
        assert archiver.dry_run is False
        assert archiver.delete_after_archive is False
        assert archiver.start_from is None

    def test_valid_init_with_time_range(self):
        window_start = datetime.datetime(2025, 1, 1)
        window_end = datetime.datetime(2025, 6, 1)
        archiver = WorkflowRunArchiver(start_from=window_start, end_before=window_end, workers=2)
        assert archiver.start_from is not None
        assert archiver.end_before is not None
        assert archiver.workers == 2
class TestBuildArchiveBundle:
    """Unit tests for WorkflowRunArchiver._build_archive_bundle zip assembly."""

    def test_bundle_contains_manifest_and_all_tables(self):
        archiver = WorkflowRunArchiver(days=90)
        manifest_data = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8")
        # Empty payloads suffice here: only the archive entry names are checked.
        table_payloads = dict.fromkeys(archiver.ARCHIVED_TABLES, b"")
        bundle_bytes = archiver._build_archive_bundle(manifest_data, table_payloads)
        with zipfile.ZipFile(io.BytesIO(bundle_bytes), "r") as zf:
            names = set(zf.namelist())
            assert "manifest.json" in names
            for table in archiver.ARCHIVED_TABLES:
                assert f"{table}.jsonl" in names, f"Missing {table}.jsonl in bundle"

    def test_bundle_missing_table_payload_raises(self):
        archiver = WorkflowRunArchiver(days=90)
        manifest_data = b"{}"
        # Supply only the first table's payload; the remaining tables are absent.
        incomplete_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"}
        with pytest.raises(ValueError, match="Missing archive payload"):
            archiver._build_archive_bundle(manifest_data, incomplete_payloads)
class TestGenerateManifest:
    """Unit test for the structure of WorkflowRunArchiver._generate_manifest output."""

    def test_manifest_structure(self):
        archiver = WorkflowRunArchiver(days=90)
        # Local import: TableStats is only needed by this one test.
        from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats

        # A MagicMock stands in for a WorkflowRun row; only attributes read
        # by _generate_manifest are set.
        run = MagicMock()
        run.id = str(uuid.uuid4())
        run.tenant_id = str(uuid.uuid4())
        run.app_id = str(uuid.uuid4())
        run.workflow_id = str(uuid.uuid4())
        run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0)
        stats = [
            TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512),
            TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024),
        ]
        manifest = archiver._generate_manifest(run, stats)
        assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION
        assert manifest["workflow_run_id"] == run.id
        assert manifest["tenant_id"] == run.tenant_id
        assert manifest["app_id"] == run.app_id
        # Per-table stats are keyed by table name inside "tables".
        assert "tables" in manifest
        assert manifest["tables"]["workflow_runs"]["row_count"] == 1
        assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123"
        assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2
class TestFilterPaidTenants:
    """Unit tests for WorkflowRunArchiver._filter_paid_tenants plan filtering."""

    def test_all_tenants_paid_when_billing_disabled(self):
        archiver = WorkflowRunArchiver(days=90)
        tenant_ids = {"t1", "t2", "t3"}
        with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
            cfg.BILLING_ENABLED = False
            # With billing disabled every tenant is treated as paid.
            result = archiver._filter_paid_tenants(tenant_ids)
            assert result == tenant_ids

    def test_empty_tenants_returns_empty(self):
        archiver = WorkflowRunArchiver(days=90)
        with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
            cfg.BILLING_ENABLED = True
            result = archiver._filter_paid_tenants(set())
            assert result == set()

    def test_only_paid_plans_returned(self):
        archiver = WorkflowRunArchiver(days=90)
        # "sandbox" is the free tier; the other two plans are paid.
        mock_bulk = {
            "t1": {"plan": "professional"},
            "t2": {"plan": "sandbox"},
            "t3": {"plan": "team"},
        }
        with (
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
        ):
            cfg.BILLING_ENABLED = True
            billing.get_plan_bulk_with_cache.return_value = mock_bulk
            result = archiver._filter_paid_tenants({"t1", "t2", "t3"})
            assert "t1" in result
            assert "t3" in result
            assert "t2" not in result

    def test_billing_api_failure_returns_empty(self):
        archiver = WorkflowRunArchiver(days=90)
        with (
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
        ):
            cfg.BILLING_ENABLED = True
            billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down")
            # A billing API failure is swallowed and treated as "no paid tenants".
            result = archiver._filter_paid_tenants({"t1"})
            assert result == set()
class TestDryRunArchive:
    """Dry-run mode must never touch archive storage."""

    @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
    def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx):
        archiver = WorkflowRunArchiver(days=90, dry_run=True)
        # With no runs to process, run() should finish without touching storage.
        with patch.object(archiver, "_get_runs_batch", return_value=[]):
            summary = archiver.run()
        mock_get_storage.assert_not_called()
        assert isinstance(summary, ArchiveSummary)
        assert summary.runs_failed == 0

View File

@@ -1,6 +1,12 @@
"""Testcontainers integration tests for rag_pipeline controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
@@ -9,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
PipelineTemplateListApi,
PublishCustomizedPipelineTemplateApi,
)
from models.dataset import PipelineCustomizedTemplate
def unwrap(func):
@@ -18,6 +25,10 @@ def unwrap(func):
class TestPipelineTemplateListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app):
api = PipelineTemplateListApi()
method = unwrap(api.get)
@@ -38,6 +49,10 @@ class TestPipelineTemplateListApi:
class TestPipelineTemplateDetailApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@@ -99,6 +114,10 @@ class TestPipelineTemplateDetailApi:
class TestCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_patch_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
@@ -136,35 +155,29 @@ class TestCustomizedPipelineTemplateApi:
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app):
def test_post_success(self, app, db_session_with_containers: Session):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
template = MagicMock()
template.yaml_content = "yaml-data"
tenant_id = str(uuid4())
template = PipelineCustomizedTemplate(
tenant_id=tenant_id,
name="Test Template",
description="Test",
chunk_structure="hierarchical",
icon={"icon": "📘"},
position=0,
yaml_content="yaml-data",
install_count=0,
language="en-US",
created_by=str(uuid4()),
)
db_session_with_containers.add(template)
db_session_with_containers.commit()
db_session_with_containers.expire_all()
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = template
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
response, status = method(api, "tpl-1")
with app.test_request_context("/"):
response, status = method(api, template.id)
assert status == 200
assert response == {"data": "yaml-data"}
@@ -173,32 +186,16 @@ class TestCustomizedPipelineTemplateApi:
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
with app.test_request_context("/"):
with pytest.raises(ValueError):
method(api, "tpl-1")
method(api, str(uuid4()))
class TestPublishCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_post_success(self, app):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)

View File

@@ -1,3 +1,7 @@
"""Testcontainers integration tests for rag_pipeline_datasets controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
@@ -19,6 +23,10 @@ def unwrap(func):
class TestCreateRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def _valid_payload(self):
return {"yaml_content": "name: test"}
@@ -33,13 +41,6 @@ class TestCreateRagPipelineDatasetApi:
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.return_value = import_info
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -47,14 +48,6 @@ class TestCreateRagPipelineDatasetApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
@@ -93,13 +86,6 @@ class TestCreateRagPipelineDatasetApi:
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -107,14 +93,6 @@ class TestCreateRagPipelineDatasetApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
@@ -143,6 +121,10 @@ class TestCreateRagPipelineDatasetApi:
class TestCreateEmptyRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_post_success(self, app):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)

View File

@@ -1,5 +1,11 @@
"""Testcontainers integration tests for rag_pipeline_import controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
RagPipelineExportApi,
@@ -18,6 +24,10 @@ def unwrap(func):
class TestRagPipelineImportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def _payload(self, mode="create"):
return {
"mode": mode,
@@ -30,7 +40,6 @@ class TestRagPipelineImportApi:
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = "completed"
@@ -39,13 +48,6 @@ class TestRagPipelineImportApi:
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -53,14 +55,6 @@ class TestRagPipelineImportApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -76,7 +70,6 @@ class TestRagPipelineImportApi:
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.FAILED
@@ -85,13 +78,6 @@ class TestRagPipelineImportApi:
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -99,14 +85,6 @@ class TestRagPipelineImportApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -122,7 +100,6 @@ class TestRagPipelineImportApi:
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.PENDING
@@ -131,13 +108,6 @@ class TestRagPipelineImportApi:
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -145,14 +115,6 @@ class TestRagPipelineImportApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -165,6 +127,10 @@ class TestRagPipelineImportApi:
class TestRagPipelineImportConfirmApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_confirm_success(self, app):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi:
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi:
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi:
class TestRagPipelineImportCheckDependenciesApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
@@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi:
service = MagicMock()
service.check_dependencies.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi:
class TestRagPipelineExportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_with_include_secret(self, app):
api = RagPipelineExportApi()
method = unwrap(api.get)
@@ -301,23 +230,8 @@ class TestRagPipelineExportApi:
service = MagicMock()
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/?include_secret=true"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,

View File

@@ -1,7 +1,13 @@
"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints."""
from __future__ import annotations
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
import services
@@ -38,6 +44,10 @@ def unwrap(func):
class TestDraftWorkflowApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_draft_success(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.get)
@@ -200,6 +210,10 @@ class TestDraftWorkflowApi:
class TestDraftRunNodes:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_iteration_node_success(self, app):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
@@ -275,6 +289,10 @@ class TestDraftRunNodes:
class TestPipelineRunApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_draft_run_success(self, app):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
@@ -337,6 +355,10 @@ class TestPipelineRunApis:
class TestDraftNodeRun:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_execution_not_found(self, app):
api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post)
@@ -364,45 +386,43 @@ class TestDraftNodeRun:
class TestPublishedPipelineApis:
def test_publish_success(self, app):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
from models.dataset import Pipeline
api = PublishedRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
tenant_id = str(uuid4())
pipeline = Pipeline(
tenant_id=tenant_id,
name="test-pipeline",
description="test",
created_by=str(uuid4()),
)
db_session_with_containers.add(pipeline)
db_session_with_containers.commit()
db_session_with_containers.expire_all()
user = MagicMock(id="u1")
workflow = MagicMock(
id="w1",
id=str(uuid4()),
created_at=naive_utc_now(),
)
session = MagicMock()
session.merge.return_value = pipeline
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
service = MagicMock()
service.publish_workflow.return_value = workflow
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
@@ -415,6 +435,10 @@ class TestPublishedPipelineApis:
class TestMiscApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_task_stop(self, app):
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
@@ -471,6 +495,10 @@ class TestMiscApis:
class TestPublishedRagPipelineRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_published_run_success(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
@@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi:
class TestDefaultBlockConfigApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_block_config_success(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
@@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi:
class TestPublishedAllRagPipelineApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_published_workflows_success(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
@@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi:
service = MagicMock()
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
@@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi:
class TestRagPipelineByIdApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_patch_success(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
@@ -640,14 +664,6 @@ class TestRagPipelineByIdApi:
service = MagicMock()
service.update_workflow.return_value = workflow
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
payload = {"marked_name": "test"}
with (
@@ -657,14 +673,6 @@ class TestRagPipelineByIdApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
@@ -700,24 +708,8 @@ class TestRagPipelineByIdApi:
workflow_service = MagicMock()
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", method="DELETE"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService",
return_value=workflow_service,
@@ -725,12 +717,7 @@ class TestRagPipelineByIdApi:
):
result = method(api, pipeline, "old-workflow")
workflow_service.delete_workflow.assert_called_once_with(
session=session,
workflow_id="old-workflow",
tenant_id="t1",
)
session.commit.assert_called_once()
workflow_service.delete_workflow.assert_called_once()
assert result == (None, 204)
def test_delete_active_workflow_rejected(self, app):
@@ -745,6 +732,10 @@ class TestRagPipelineByIdApi:
class TestRagPipelineWorkflowLastRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_last_run_success(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
@@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi:
class TestRagPipelineDatasourceVariableApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_set_datasource_variables_success(self, app):
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)

View File

@@ -1,7 +1,10 @@
"""Testcontainers integration tests for controllers.console.explore.conversation endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.conversation as conversation_module
@@ -48,24 +51,12 @@ def user():
return user
@pytest.fixture(autouse=True)
def mock_db_and_session():
with (
patch.object(
conversation_module,
"db",
MagicMock(session=MagicMock(), engine=MagicMock()),
),
patch(
"controllers.console.explore.conversation.Session",
MagicMock(),
),
):
yield
class TestConversationListApi:
def test_get_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@@ -90,7 +81,7 @@ class TestConversationListApi:
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
def test_last_conversation_not_exists(self, app, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@@ -106,7 +97,7 @@ class TestConversationListApi:
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app: Flask, non_chat_app):
def test_wrong_app_mode(self, app, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@@ -116,7 +107,11 @@ class TestConversationListApi:
class TestConversationApi:
def test_delete_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_delete_success(self, app, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@@ -134,7 +129,7 @@ class TestConversationApi:
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app: Flask, chat_app, user):
def test_delete_not_found(self, app, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@@ -150,7 +145,7 @@ class TestConversationApi:
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
def test_delete_wrong_app_mode(self, app, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@@ -160,7 +155,11 @@ class TestConversationApi:
class TestConversationRenameApi:
def test_rename_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_rename_success(self, app, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@@ -179,7 +178,7 @@ class TestConversationRenameApi:
assert result["id"] == "cid"
def test_rename_not_found(self, app: Flask, chat_app, user):
def test_rename_not_found(self, app, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@@ -197,7 +196,11 @@ class TestConversationRenameApi:
class TestConversationPinApi:
def test_pin_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_pin_success(self, app, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
@@ -215,7 +218,11 @@ class TestConversationPinApi:
class TestConversationUnPinApi:
def test_unpin_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_unpin_success(self, app, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)

View File

@@ -5,6 +5,7 @@ Covers:
- License status caching (get_cached_license_status)
"""
from datetime import datetime
from unittest.mock import patch
import pytest
@@ -15,9 +16,178 @@ from services.enterprise.enterprise_service import (
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult,
EnterpriseService,
WebAppSettings,
WorkspacePermission,
try_join_default_workspace,
)
MODULE = "services.enterprise.enterprise_service"
class TestEnterpriseServiceInfo:
    """Info endpoints must delegate verbatim to EnterpriseRequest and pass the payload through."""

    def test_get_info_delegates(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = {"version": "1.0"}
            info = EnterpriseService.get_info()
        # One forwarded call, response returned untouched.
        mock_request.send_request.assert_called_once_with("GET", "/info")
        assert info == {"version": "1.0"}

    def test_get_workspace_info_delegates(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = {"name": "ws"}
            workspace_info = EnterpriseService.get_workspace_info("tenant-1")
        # Tenant id is interpolated into the request path.
        mock_request.send_request.assert_called_once_with("GET", "/workspace/tenant-1/info")
        assert workspace_info == {"name": "ws"}
class TestSsoSettingsLastUpdateTime:
    """Parsing of the SSO-settings last-update timestamps (app and workspace variants)."""

    def test_app_sso_parses_valid_timestamp(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = "2025-01-15T10:30:00+00:00"
            parsed = EnterpriseService.get_app_sso_settings_last_update_time()
        # A valid ISO-8601 string becomes a datetime.
        assert isinstance(parsed, datetime)
        assert parsed.year == 2025

    def test_app_sso_raises_on_empty(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = ""
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.get_app_sso_settings_last_update_time()

    def test_app_sso_raises_on_invalid_format(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = "not-a-date"
            with pytest.raises(ValueError, match="Invalid date format"):
                EnterpriseService.get_app_sso_settings_last_update_time()

    def test_workspace_sso_parses_valid_timestamp(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = "2025-06-01T00:00:00+00:00"
            parsed = EnterpriseService.get_workspace_sso_settings_last_update_time()
        assert isinstance(parsed, datetime)

    def test_workspace_sso_raises_on_empty(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            # None response is treated the same as an empty string.
            mock_request.send_request.return_value = None
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.get_workspace_sso_settings_last_update_time()
class TestWorkspacePermissionService:
    """Validation and parsing behavior of WorkspacePermissionService.get_permission."""

    def test_raises_on_empty_workspace_id(self):
        # Guard clause: an empty workspace id is rejected before any request is sent.
        with pytest.raises(ValueError, match="workspace_id must be provided"):
            EnterpriseService.WorkspacePermissionService.get_permission("")

    def test_raises_on_missing_data(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = None
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.WorkspacePermissionService.get_permission("ws-1")

    def test_raises_on_missing_permission_key(self):
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            # Response present but without the expected "permission" key.
            mock_request.send_request.return_value = {"other": "data"}
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.WorkspacePermissionService.get_permission("ws-1")

    def test_returns_parsed_permission(self):
        payload = {
            "permission": {
                "workspaceId": "ws-1",
                "allowMemberInvite": True,
                "allowOwnerTransfer": False,
            }
        }
        with patch(f"{MODULE}.EnterpriseRequest") as mock_request:
            mock_request.send_request.return_value = payload
            parsed = EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
        # camelCase response fields map onto the snake_case model.
        assert isinstance(parsed, WorkspacePermission)
        assert parsed.workspace_id == "ws-1"
        assert parsed.allow_member_invite is True
        assert parsed.allow_owner_transfer is False
class TestWebAppAuth:
    """WebAppAuth: single/batch access checks, access-mode CRUD, and webapp cleanup."""

    def test_is_user_allowed_returns_result_field(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            req.send_request.return_value = {"result": True}
            assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is True

    def test_is_user_allowed_defaults_false(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            # Missing "result" key falls back to a deny-by-default False.
            req.send_request.return_value = {}
            assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is False

    def test_batch_is_user_allowed_returns_empty_for_no_apps(self):
        # Empty app list short-circuits to {} without any remote call.
        assert EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", []) == {}

    def test_batch_is_user_allowed_raises_on_empty_response(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            req.send_request.return_value = None
            with pytest.raises(ValueError, match="No data found"):
                EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", ["a1"])

    def test_get_app_access_mode_raises_on_empty_app_id(self):
        with pytest.raises(ValueError, match="app_id must be provided"):
            EnterpriseService.WebAppAuth.get_app_access_mode_by_id("")

    def test_get_app_access_mode_returns_settings(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            req.send_request.return_value = {"accessMode": "public"}
            result = EnterpriseService.WebAppAuth.get_app_access_mode_by_id("a1")
            # Response is wrapped in a WebAppSettings model.
            assert isinstance(result, WebAppSettings)
            assert result.access_mode == "public"

    def test_batch_get_returns_empty_for_no_apps(self):
        assert EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id([]) == {}

    def test_batch_get_maps_access_modes(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            req.send_request.return_value = {"accessModes": {"a1": "public", "a2": "private"}}
            result = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1", "a2"])
            # Each app id maps to its own settings object.
            assert result["a1"].access_mode == "public"
            assert result["a2"].access_mode == "private"

    def test_batch_get_raises_on_invalid_format(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            # "accessModes" must be a mapping; anything else is rejected.
            req.send_request.return_value = {"accessModes": "not-a-dict"}
            with pytest.raises(ValueError, match="Invalid data format"):
                EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1"])

    def test_update_access_mode_raises_on_empty_app_id(self):
        with pytest.raises(ValueError, match="app_id must be provided"):
            EnterpriseService.WebAppAuth.update_app_access_mode("", "public")

    def test_update_access_mode_raises_on_invalid_mode(self):
        with pytest.raises(ValueError, match="access_mode must be"):
            EnterpriseService.WebAppAuth.update_app_access_mode("a1", "invalid")

    def test_update_access_mode_delegates_and_returns(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            req.send_request.return_value = {"result": True}
            result = EnterpriseService.WebAppAuth.update_app_access_mode("a1", "public")
            assert result is True
            # Body uses the server's camelCase field names.
            req.send_request.assert_called_once_with(
                "POST", "/webapp/access-mode", json={"appId": "a1", "accessMode": "public"}
            )

    def test_cleanup_webapp_raises_on_empty_app_id(self):
        with pytest.raises(ValueError, match="app_id must be provided"):
            EnterpriseService.WebAppAuth.cleanup_webapp("")

    def test_cleanup_webapp_delegates(self):
        with patch(f"{MODULE}.EnterpriseRequest") as req:
            EnterpriseService.WebAppAuth.cleanup_webapp("a1")
            # Cleanup is a DELETE with the app id as a query parameter.
            req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"})
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):

View File

@@ -7,14 +7,20 @@ This module covers the pre-uninstall plugin hook behavior:
from unittest.mock import patch
import pytest
from httpx import HTTPStatusError
from configs import dify_config
from services.enterprise.plugin_manager_service import (
CheckCredentialPolicyComplianceRequest,
CredentialPolicyViolationError,
PluginCredentialType,
PluginManagerService,
PreUninstallPluginRequest,
)
MODULE = "services.enterprise.plugin_manager_service"
class TestTryPreUninstallPlugin:
def test_try_pre_uninstall_plugin_success(self):
@@ -88,3 +94,46 @@ class TestTryPreUninstallPlugin:
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
class TestCheckCredentialPolicyCompliance:
    """Credential-policy compliance checks and request-model serialization."""

    def _request(self, cred_type=PluginCredentialType.MODEL):
        # Shared factory for a minimal compliance-check request body.
        return CheckCredentialPolicyComplianceRequest(
            dify_credential_id="cred-1", provider="openai", credential_type=cred_type
        )

    def test_passes_when_result_true(self):
        with patch(f"{MODULE}.EnterprisePluginManagerRequest") as manager_request:
            manager_request.send_request.return_value = {"result": True}
            # A truthy result means no exception is raised.
            PluginManagerService.check_credential_policy_compliance(self._request())
        manager_request.send_request.assert_called_once()

    def test_raises_violation_when_result_false(self):
        with patch(f"{MODULE}.EnterprisePluginManagerRequest") as manager_request:
            manager_request.send_request.return_value = {"result": False}
            with pytest.raises(CredentialPolicyViolationError, match="Credentials not available"):
                PluginManagerService.check_credential_policy_compliance(self._request())

    def test_raises_violation_on_invalid_response_format(self):
        with patch(f"{MODULE}.EnterprisePluginManagerRequest") as manager_request:
            # Non-dict responses are treated as an error condition.
            manager_request.send_request.return_value = "not-a-dict"
            with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
                PluginManagerService.check_credential_policy_compliance(self._request())

    def test_raises_violation_on_api_exception(self):
        with patch(f"{MODULE}.EnterprisePluginManagerRequest") as manager_request:
            # Transport failures are converted into policy-violation errors.
            manager_request.send_request.side_effect = ConnectionError("network fail")
            with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
                PluginManagerService.check_credential_policy_compliance(self._request())

    def test_model_dump_serializes_credential_type_as_number(self):
        dumped = self._request(PluginCredentialType.TOOL).model_dump()
        # The enum serializes to its numeric wire value.
        assert dumped["credential_type"] == 1
        assert dumped["dify_credential_id"] == "cred-1"

    def test_model_credential_type_values(self):
        assert PluginCredentialType.MODEL.to_number() == 0
        assert PluginCredentialType.TOOL.to_number() == 1

View File

@@ -0,0 +1,183 @@
from unittest.mock import MagicMock, patch
from models.account import TenantPluginAutoUpgradeStrategy
MODULE = "services.plugin.plugin_auto_upgrade_service"
def _patched_session():
    """Build (Session patcher, db patcher, mock session) for the auto-upgrade module.

    The patched Session class yields the returned mock session when used as a
    context manager; __exit__ returns False so exceptions propagate.
    """
    mock_session = MagicMock()
    mock_session_cls = MagicMock()
    mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session)
    mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
    return (
        patch(f"{MODULE}.Session", mock_session_cls),
        patch(f"{MODULE}.db"),
        mock_session,
    )
class TestGetStrategy:
    """PluginAutoUpgradeService.get_strategy returns the row or None."""

    def test_returns_strategy_when_found(self):
        session_patch, db_patch, mock_session = _patched_session()
        stored_strategy = MagicMock()
        mock_session.query.return_value.where.return_value.first.return_value = stored_strategy
        with session_patch, db_patch:
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            assert PluginAutoUpgradeService.get_strategy("t1") is stored_strategy

    def test_returns_none_when_not_found(self):
        session_patch, db_patch, mock_session = _patched_session()
        mock_session.query.return_value.where.return_value.first.return_value = None
        with session_patch, db_patch:
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            assert PluginAutoUpgradeService.get_strategy("t1") is None
class TestChangeStrategy:
    """change_strategy inserts a new row or mutates the existing one in place."""

    def test_creates_new_strategy(self):
        session_patch, db_patch, mock_session = _patched_session()
        # No existing row -> the service must add and commit a fresh strategy.
        mock_session.query.return_value.where.return_value.first.return_value = None
        with session_patch, db_patch, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strategy_cls:
            strategy_cls.return_value = MagicMock()
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            outcome = PluginAutoUpgradeService.change_strategy(
                "t1",
                TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
                3,
                TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
                [],
                [],
            )
            assert outcome is True
            mock_session.add.assert_called_once()
            mock_session.commit.assert_called_once()

    def test_updates_existing_strategy(self):
        session_patch, db_patch, mock_session = _patched_session()
        stored = MagicMock()
        mock_session.query.return_value.where.return_value.first.return_value = stored
        with session_patch, db_patch:
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            outcome = PluginAutoUpgradeService.change_strategy(
                "t1",
                TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
                5,
                TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
                ["p1"],
                ["p2"],
            )
            assert outcome is True
            # Every field of the existing row is overwritten with the new values.
            assert stored.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
            assert stored.upgrade_time_of_day == 5
            assert stored.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
            assert stored.exclude_plugins == ["p1"]
            assert stored.include_plugins == ["p2"]
            mock_session.commit.assert_called_once()
class TestExcludePlugin:
    """exclude_plugin behavior across the three upgrade modes (all/partial/exclude)."""

    def test_creates_default_strategy_when_none_exists(self):
        p1, p2, session = _patched_session()
        session.query.return_value.where.return_value.first.return_value = None
        with (
            p1,
            p2,
            patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
            patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
        ):
            # Enum members are stubbed as plain strings so the service's
            # comparisons work against the mocked strategy class.
            strat_cls.StrategySetting.FIX_ONLY = "fix_only"
            strat_cls.UpgradeMode.EXCLUDE = "exclude"
            cs.return_value = True
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            result = PluginAutoUpgradeService.exclude_plugin("t1", "plugin-1")
            assert result is True
            # With no stored strategy, the service bootstraps one via change_strategy.
            cs.assert_called_once()

    def test_appends_to_exclude_list_in_exclude_mode(self):
        p1, p2, session = _patched_session()
        existing = MagicMock()
        existing.upgrade_mode = "exclude"
        existing.exclude_plugins = ["p-existing"]
        session.query.return_value.where.return_value.first.return_value = existing
        with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
            strat_cls.UpgradeMode.EXCLUDE = "exclude"
            strat_cls.UpgradeMode.PARTIAL = "partial"
            strat_cls.UpgradeMode.ALL = "all"
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            result = PluginAutoUpgradeService.exclude_plugin("t1", "p-new")
            assert result is True
            # In exclude mode the plugin id is appended to the exclusion list.
            assert existing.exclude_plugins == ["p-existing", "p-new"]
            session.commit.assert_called_once()

    def test_removes_from_include_list_in_partial_mode(self):
        p1, p2, session = _patched_session()
        existing = MagicMock()
        existing.upgrade_mode = "partial"
        existing.include_plugins = ["p1", "p2"]
        session.query.return_value.where.return_value.first.return_value = existing
        with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
            strat_cls.UpgradeMode.EXCLUDE = "exclude"
            strat_cls.UpgradeMode.PARTIAL = "partial"
            strat_cls.UpgradeMode.ALL = "all"
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
            assert result is True
            # In partial mode exclusion means dropping the plugin from the include list.
            assert existing.include_plugins == ["p2"]

    def test_switches_to_exclude_mode_from_all(self):
        p1, p2, session = _patched_session()
        existing = MagicMock()
        existing.upgrade_mode = "all"
        session.query.return_value.where.return_value.first.return_value = existing
        with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
            strat_cls.UpgradeMode.EXCLUDE = "exclude"
            strat_cls.UpgradeMode.PARTIAL = "partial"
            strat_cls.UpgradeMode.ALL = "all"
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
            assert result is True
            # "all" mode flips to "exclude" with the plugin as the first exclusion.
            assert existing.upgrade_mode == "exclude"
            assert existing.exclude_plugins == ["p1"]

    def test_no_duplicate_in_exclude_list(self):
        p1, p2, session = _patched_session()
        existing = MagicMock()
        existing.upgrade_mode = "exclude"
        existing.exclude_plugins = ["p1"]
        session.query.return_value.where.return_value.first.return_value = existing
        with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
            strat_cls.UpgradeMode.EXCLUDE = "exclude"
            strat_cls.UpgradeMode.PARTIAL = "partial"
            strat_cls.UpgradeMode.ALL = "all"
            from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

            # Excluding an already-excluded plugin must not duplicate the entry.
            PluginAutoUpgradeService.exclude_plugin("t1", "p1")
            assert existing.exclude_plugins == ["p1"]

View File

@@ -0,0 +1,75 @@
from unittest.mock import MagicMock, patch
from models.account import TenantPluginPermission
MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch Session(db.engine) to return a mock session as context manager."""
session = MagicMock()
session_cls = MagicMock()
session_cls.return_value.__enter__ = MagicMock(return_value=session)
session_cls.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.Session", session_cls)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
class TestGetPermission:
    """Unit tests for PluginPermissionService.get_permission."""

    def test_returns_permission_when_found(self):
        session_patch, db_patch, session = _patched_session()
        stored = MagicMock()
        session.query.return_value.where.return_value.first.return_value = stored
        with session_patch, db_patch:
            from services.plugin.plugin_permission_service import PluginPermissionService

            assert PluginPermissionService.get_permission("t1") is stored

    def test_returns_none_when_not_found(self):
        session_patch, db_patch, session = _patched_session()
        session.query.return_value.where.return_value.first.return_value = None
        with session_patch, db_patch:
            from services.plugin.plugin_permission_service import PluginPermissionService

            assert PluginPermissionService.get_permission("t1") is None
class TestChangePermission:
    """Unit tests for PluginPermissionService.change_permission."""

    def test_creates_new_permission_when_not_exists(self):
        # No existing row: the service must create one, add it, and commit.
        p1, p2, session = _patched_session()
        session.query.return_value.where.return_value.first.return_value = None
        with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
            perm_cls.return_value = MagicMock()
            from services.plugin.plugin_permission_service import PluginPermissionService

            result = PluginPermissionService.change_permission(
                "t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
            )
            # change_permission reports success with True (asserted previously
            # only in the integration tests; the unit test should pin it too).
            assert result is True
            session.add.assert_called_once()
            session.commit.assert_called_once()

    def test_updates_existing_permission(self):
        # Existing row: it is mutated in place and committed, never re-added.
        p1, p2, session = _patched_session()
        existing = MagicMock()
        session.query.return_value.where.return_value.first.return_value = existing
        with p1, p2:
            from services.plugin.plugin_permission_service import PluginPermissionService

            result = PluginPermissionService.change_permission(
                "t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
            )
            assert result is True
            assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
            assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
            session.commit.assert_called_once()
            session.add.assert_not_called()

View File

@@ -0,0 +1,135 @@
import datetime
from unittest.mock import MagicMock, patch
from services.retention.conversation.messages_clean_policy import (
BillingDisabledPolicy,
BillingSandboxPolicy,
SimpleMessage,
create_message_clean_policy,
)
MODULE = "services.retention.conversation.messages_clean_policy"


def _msg(msg_id: str, app_id: str, days_ago: int = 0) -> SimpleMessage:
    """Create a SimpleMessage timestamped *days_ago* days before now (UTC)."""
    created = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days_ago)
    return SimpleMessage(id=msg_id, app_id=app_id, created_at=created)
class TestBillingDisabledPolicy:
    """With billing disabled, every message is eligible for cleanup."""

    def test_returns_all_message_ids(self):
        messages = [_msg("m1", "app1"), _msg("m2", "app2"), _msg("m3", "app1")]
        deletable = BillingDisabledPolicy().filter_message_ids(messages, {"app1": "t1", "app2": "t2"})
        assert set(deletable) == {"m1", "m2", "m3"}

    def test_empty_messages_returns_empty(self):
        assert BillingDisabledPolicy().filter_message_ids([], {}) == []
class TestBillingSandboxPolicy:
    """Sandbox policy: only messages of (expired) sandbox tenants are deleted."""

    def _policy(self, plans, *, graceful_days=21, whitelist=None, now=1_000_000_000):
        # Inject a canned plan lookup and a fixed clock for determinism.
        return BillingSandboxPolicy(
            plan_provider=lambda _ids: plans,
            graceful_period_days=graceful_days,
            tenant_whitelist=whitelist,
            current_timestamp=now,
        )

    def test_empty_messages_returns_empty(self):
        assert self._policy({}).filter_message_ids([], {"app1": "t1"}) == []

    def test_empty_app_to_tenant_returns_empty(self):
        assert self._policy({}).filter_message_ids([_msg("m1", "app1")], {}) == []

    def test_empty_plans_returns_empty(self):
        assert self._policy({}).filter_message_ids([_msg("m1", "app1")], {"app1": "t1"}) == []

    def test_non_sandbox_tenant_skipped(self):
        policy = self._policy({"t1": {"plan": "professional", "expiration_date": 0}})
        assert policy.filter_message_ids([_msg("m1", "app1")], {"app1": "t1"}) == []

    def test_sandbox_no_previous_subscription_deletes(self):
        policy = self._policy({"t1": {"plan": "sandbox", "expiration_date": -1}})
        assert policy.filter_message_ids([_msg("m1", "app1")], {"app1": "t1"}) == ["m1"]

    def test_sandbox_expired_beyond_grace_period_deletes(self):
        now = 1_000_000_000
        # Expired 22 days ago, i.e. outside the 21-day grace window.
        expiry = now - 22 * 24 * 60 * 60
        policy = self._policy({"t1": {"plan": "sandbox", "expiration_date": expiry}}, now=now)
        assert policy.filter_message_ids([_msg("m1", "app1")], {"app1": "t1"}) == ["m1"]

    def test_sandbox_within_grace_period_kept(self):
        now = 1_000_000_000
        # Expired only 10 days ago, still inside the 21-day grace window.
        expiry = now - 10 * 24 * 60 * 60
        policy = self._policy({"t1": {"plan": "sandbox", "expiration_date": expiry}}, now=now)
        assert policy.filter_message_ids([_msg("m1", "app1")], {"app1": "t1"}) == []

    def test_whitelisted_tenant_skipped(self):
        policy = self._policy({"t1": {"plan": "sandbox", "expiration_date": -1}}, whitelist=["t1"])
        assert policy.filter_message_ids([_msg("m1", "app1")], {"app1": "t1"}) == []

    def test_message_without_tenant_mapping_skipped(self):
        policy = self._policy({"t1": {"plan": "sandbox", "expiration_date": -1}})
        assert policy.filter_message_ids([_msg("m1", "unmapped_app")], {"app1": "t1"}) == []

    def test_mixed_tenants_only_sandbox_deleted(self):
        plans = {
            "t_sandbox": {"plan": "sandbox", "expiration_date": -1},
            "t_pro": {"plan": "professional", "expiration_date": 0},
        }
        mapping = {"app_sandbox": "t_sandbox", "app_pro": "t_pro"}
        messages = [_msg("m1", "app_sandbox"), _msg("m2", "app_pro")]
        assert self._policy(plans).filter_message_ids(messages, mapping) == ["m1"]
class TestCreateMessageCleanPolicy:
    """The factory selects the policy implementation from billing config."""

    def test_billing_disabled_returns_disabled_policy(self):
        with patch(f"{MODULE}.dify_config") as cfg:
            cfg.BILLING_ENABLED = False
            assert isinstance(create_message_clean_policy(), BillingDisabledPolicy)

    def test_billing_enabled_returns_sandbox_policy(self):
        with (
            patch(f"{MODULE}.dify_config") as cfg,
            patch(f"{MODULE}.BillingService") as bs,
        ):
            cfg.BILLING_ENABLED = True
            bs.get_expired_subscription_cleanup_whitelist.return_value = ["wl1"]
            bs.get_plan_bulk_with_cache = MagicMock()
            policy = create_message_clean_policy(graceful_period_days=30)
            assert isinstance(policy, BillingSandboxPolicy)

View File

@@ -0,0 +1,598 @@
from unittest.mock import MagicMock, Mock, patch
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderType
from services.tools.tools_transform_service import ToolTransformService
MODULE = "services.tools.tools_transform_service"
class TestToolTransformService:
    """Test cases for ToolTransformService.convert_tool_entity_to_api_entity."""

    @staticmethod
    def _param(name, label, form=ToolParameter.ToolParameterForm.FORM):
        """Build a mock ToolParameter with the given name, label and form."""
        param = Mock(spec=ToolParameter)
        param.name = name
        param.form = form
        param.type = "string"
        param.label = label
        return param

    @staticmethod
    def _tool(base_parameters, runtime_parameters):
        """Build a mock Tool exposing *base_parameters* and *runtime_parameters*.

        The identity/description fields are fixed so every test sees the same
        author ("test_author") and name ("test_tool").
        """
        tool = Mock(spec=Tool)
        tool.entity = Mock()
        tool.entity.parameters = base_parameters
        tool.entity.identity = Mock()
        tool.entity.identity.author = "test_author"
        tool.entity.identity.name = "test_tool"
        tool.entity.identity.label = I18nObject(en_US="Test Tool")
        tool.entity.description = Mock()
        tool.entity.description.human = I18nObject(en_US="Test description")
        tool.entity.output_schema = {}
        tool.get_runtime_parameters.return_value = runtime_parameters
        # fork_tool_runtime returns the same tool so the conversion reads our mocks.
        tool.fork_tool_runtime.return_value = tool
        return tool

    def test_convert_tool_with_parameter_override(self):
        """Runtime parameters correctly override base parameters by name."""
        base1 = self._param("param1", "Base Param 1")
        base2 = self._param("param2", "Base Param 2")
        runtime1 = self._param("param1", "Runtime Param 1")  # different label to verify override
        mock_tool = self._tool([base1, base2], [runtime1])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.author == "test_author"
        assert result.name == "test_tool"
        assert result.parameters is not None
        assert len(result.parameters) == 2
        overridden = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden is not None
        assert overridden.label == "Runtime Param 1"  # runtime version wins
        untouched = next((p for p in result.parameters if p.name == "param2"), None)
        assert untouched is not None
        assert untouched.label == "Base Param 2"  # base version preserved

    def test_convert_tool_with_additional_runtime_parameters(self):
        """Additional runtime parameters are appended to the final list."""
        base1 = self._param("param1", "Base Param 1")
        runtime1 = self._param("param1", "Runtime Param 1")
        runtime2 = self._param("runtime_only", "Runtime Only Param")
        mock_tool = self._tool([base1], [runtime1, runtime2])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 2
        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "runtime_only" in param_names
        overridden = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden is not None
        assert overridden.label == "Runtime Param 1"
        new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
        assert new_param is not None
        assert new_param.label == "Runtime Only Param"

    def test_convert_tool_with_non_form_runtime_parameters(self):
        """Non-FORM runtime parameters are not added as new parameters."""
        base1 = self._param("param1", "Base Param 1")
        runtime1 = self._param("param1", "Runtime Param 1")
        runtime2 = self._param("llm_param", "LLM Param", form=ToolParameter.ToolParameterForm.LLM)
        mock_tool = self._tool([base1], [runtime1, runtime2])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 1  # only the FORM parameter remains
        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "llm_param" not in param_names

    def test_convert_tool_with_empty_parameters(self):
        """Conversion with empty base and runtime parameter lists yields no parameters."""
        mock_tool = self._tool([], [])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_with_none_parameters(self):
        """Conversion when base parameters is None behaves like an empty list."""
        mock_tool = self._tool(None, [])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_parameter_order_preserved(self):
        """Base parameter order is kept; new runtime parameters come last."""
        base1 = self._param("param1", "Base Param 1")
        base2 = self._param("param2", "Base Param 2")
        base3 = self._param("param3", "Base Param 3")
        runtime2 = self._param("param2", "Runtime Param 2")  # overrides the middle one
        runtime4 = self._param("param4", "Runtime Param 4")  # brand-new parameter
        mock_tool = self._tool([base1, base2, base3], [runtime2, runtime4])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 4
        # Base parameters first, then new runtime parameters.
        param_names = [p.name for p in result.parameters]
        assert param_names == ["param1", "param2", "param3", "param4"]
        # The overridden slot carries the runtime version.
        param2 = result.parameters[1]
        assert param2.name == "param2"
        assert param2.label == "Runtime Param 2"
class TestWorkflowProviderToUserProvider:
    """Test cases for ToolTransformService.workflow_provider_to_user_provider."""

    @staticmethod
    def _controller(provider_id, **identity_overrides):
        """Build a mock WorkflowToolProviderController with a default identity.

        Keyword overrides replace individual identity attributes
        (author, name, description, icon, icon_dark, label).
        """
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        identity = {
            "author": "test_author",
            "name": "test_workflow_tool",
            "description": I18nObject(en_US="Test description"),
            "icon": {"type": "emoji", "content": "🔧"},
            "icon_dark": None,
            "label": I18nObject(en_US="Test Workflow Tool"),
        }
        identity.update(identity_overrides)
        controller = Mock(spec=WorkflowToolProviderController)
        controller.provider_id = provider_id
        controller.entity = Mock()
        controller.entity.identity = Mock()
        for attr, value in identity.items():
            setattr(controller.entity.identity, attr, value)
        return controller

    def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
        """workflow_app_id is passed through to the resulting entity."""
        workflow_app_id = "app_123"
        mock_controller = self._controller("provider_123")

        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["label1", "label2"],
            workflow_app_id=workflow_app_id,
        )

        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_123"
        assert result.author == "test_author"
        assert result.name == "test_workflow_tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["label1", "label2"]
        assert result.is_team_authorization is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []

    def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
        """Omitting workflow_app_id leaves it None on the result."""
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=self._controller("provider_123"),
            labels=["label1"],
        )

        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_123"
        assert result.workflow_app_id is None
        assert result.labels == ["label1"]

    def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
        """Explicit None workflow_app_id and labels are handled gracefully."""
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=self._controller("provider_123"),
            labels=None,
            workflow_app_id=None,
        )

        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_123"
        assert result.workflow_app_id is None
        assert result.labels == []

    def test_workflow_provider_to_user_provider_preserves_other_fields(self):
        """All identity fields are copied onto the resulting entity."""
        workflow_app_id = "app_456"
        mock_controller = self._controller(
            "provider_456",
            author="another_author",
            name="another_workflow_tool",
            description=I18nObject(en_US="Another description", zh_Hans="Another description"),
            icon={"type": "emoji", "content": "⚙️"},
            icon_dark={"type": "emoji", "content": "🔧"},
            label=I18nObject(en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"),
        )

        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["automation", "workflow"],
            workflow_app_id=workflow_app_id,
        )

        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == "provider_456"
        assert result.author == "another_author"
        assert result.name == "another_workflow_tool"
        assert result.description.en_US == "Another description"
        assert result.description.zh_Hans == "Another description"
        assert result.icon == {"type": "emoji", "content": "⚙️"}
        assert result.icon_dark == {"type": "emoji", "content": "🔧"}
        assert result.label.en_US == "Another Workflow Tool"
        assert result.label.zh_Hans == "Another Workflow Tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["automation", "workflow"]
        assert result.masked_credentials == {}
        assert result.is_team_authorization is True
        assert result.allow_delete is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []
class TestGetToolProviderIconUrl:
    """get_tool_provider_icon_url resolves icons differently per provider type."""

    def test_builtin_provider_returns_console_url(self):
        with patch(f"{MODULE}.dify_config") as cfg:
            cfg.CONSOLE_API_URL = "https://app.dify.ai"
            url = ToolTransformService.get_tool_provider_icon_url("builtin", "google", "icon.png")
            assert "/builtin/google/icon" in url
            assert url.startswith("https://app.dify.ai/console/api/workspaces/current/tool-provider")

    def test_builtin_provider_with_no_console_url(self):
        # A missing CONSOLE_API_URL still yields the icon path portion.
        with patch(f"{MODULE}.dify_config") as cfg:
            cfg.CONSOLE_API_URL = None
            url = ToolTransformService.get_tool_provider_icon_url("builtin", "slack", "icon.png")
            assert "/builtin/slack/icon" in url

    def test_api_provider_parses_json_icon(self):
        parsed = ToolTransformService.get_tool_provider_icon_url(
            "api", "my-api", '{"background": "#fff", "content": "A"}'
        )
        assert parsed == {"background": "#fff", "content": "A"}

    def test_api_provider_returns_dict_icon_directly(self):
        icon = {"background": "#000", "content": "B"}
        assert ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon) == icon

    def test_api_provider_returns_fallback_on_invalid_json(self):
        fallback = ToolTransformService.get_tool_provider_icon_url("api", "my-api", "not-json")
        # NOTE(review): the expected content is spelled as a surrogate-pair escape;
        # presumably it matches the service's fallback constant exactly — confirm.
        assert fallback == {"background": "#252525", "content": "\ud83d\ude01"}

    def test_workflow_provider_behaves_like_api(self):
        icon = {"background": "#123", "content": "W"}
        assert ToolTransformService.get_tool_provider_icon_url("workflow", "wf", icon) == icon

    def test_mcp_returns_icon_as_is(self):
        assert ToolTransformService.get_tool_provider_icon_url("mcp", "srv", "icon-value") == "icon-value"

    def test_unknown_type_returns_empty(self):
        assert ToolTransformService.get_tool_provider_icon_url("unknown", "x", "i") == ""
class TestRepackProvider:
    """repack_provider rewrites icon fields in place via get_tool_provider_icon_url."""

    def test_repacks_dict_provider_icon(self):
        provider = {"type": "builtin", "name": "google", "icon": "old"}
        with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/new-url") as icon_fn:
            ToolTransformService.repack_provider("t1", provider)
            assert provider["icon"] == "/new-url"
            icon_fn.assert_called_once_with(provider_type="builtin", provider_name="google", icon="old")

    def test_repacks_tool_provider_api_entity_without_plugin(self):
        entity = MagicMock(spec=ToolProviderApiEntity)
        entity.plugin_id = None
        entity.type = ToolProviderType.BUILT_IN
        entity.name = "slack"
        entity.icon = "icon.svg"
        entity.icon_dark = "dark.svg"
        with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/url"):
            ToolTransformService.repack_provider("t1", entity)
            # Both light and dark icons are rewritten to the resolved URL.
            assert entity.icon == "/url"
            assert entity.icon_dark == "/url"
class TestConvertMcpSchemaToParameter:
    """convert_mcp_schema_to_parameter maps a JSON schema to ToolParameter objects."""

    def test_simple_object_schema(self):
        schema = {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"},
                "count": {"type": "integer", "description": "Result count"},
            },
            "required": ["query"],
        }
        params = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert len(params) == 2
        by_name = {p.name: p for p in params}
        assert by_name["query"].required is True
        assert by_name["count"].required is False
        # JSON-schema "integer" is surfaced as the generic "number" type.
        assert by_name["count"].type.value == "number"

    def test_float_maps_to_number(self):
        schema = {"type": "object", "properties": {"rate": {"type": "float"}}, "required": []}
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert param.type.value == "number"

    def test_array_type_attaches_input_schema(self):
        schema = {
            "type": "object",
            "properties": {"items": {"type": "array", "description": "Items", "items": {"type": "string"}}},
            "required": [],
        }
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert param.input_schema is not None

    def test_non_object_schema_returns_empty(self):
        assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "string"}) == []

    def test_missing_properties_returns_empty(self):
        assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "object"}) == []

    def test_list_type_uses_first_element(self):
        # A union type list like ["string", "null"] resolves to its first entry.
        schema = {"type": "object", "properties": {"f": {"type": ["string", "null"]}}, "required": []}
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert param.type.value == "string"

    def test_missing_description_defaults_empty(self):
        schema = {"type": "object", "properties": {"f": {"type": "string"}}, "required": []}
        (param,) = ToolTransformService.convert_mcp_schema_to_parameter(schema)
        assert param.llm_description == ""
class TestApiProviderToController:
    """api_provider_to_controller maps stored auth_type strings to ApiProviderAuthType."""

    def _assert_auth_mapping(self, auth_type, expected):
        """Run api_provider_to_controller for *auth_type* and verify the controller
        is constructed with the *expected* ApiProviderAuthType."""
        db_provider = MagicMock()
        db_provider.credentials = {"auth_type": auth_type}
        with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
            ctrl_cls.from_db.return_value = MagicMock()
            ToolTransformService.api_provider_to_controller(db_provider)
            ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=expected)

    def test_api_key_header_auth(self):
        self._assert_auth_mapping("api_key_header", ApiProviderAuthType.API_KEY_HEADER)

    def test_api_key_query_auth(self):
        self._assert_auth_mapping("api_key_query", ApiProviderAuthType.API_KEY_QUERY)

    def test_legacy_api_key_maps_to_header(self):
        # Historical "api_key" values predate the header/query split; they mean header auth.
        self._assert_auth_mapping("api_key", ApiProviderAuthType.API_KEY_HEADER)

    def test_unknown_auth_defaults_to_none(self):
        self._assert_auth_mapping("something_else", ApiProviderAuthType.NONE)

6
pnpm-lock.yaml generated
View File

@@ -249,9 +249,6 @@ catalogs:
autoprefixer:
specifier: 10.4.27
version: 10.4.27
axios:
specifier: ^1.14.0
version: 1.14.0
class-variance-authority:
specifier: 0.7.1
version: 0.7.1
@@ -573,6 +570,7 @@ overrides:
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
assert: npm:@nolyfill/assert@^1.0.26
axios: 1.14.0
brace-expansion@<2.0.2: 2.0.2
canvas: ^3.2.2
devalue@<5.3.2: 5.3.2
@@ -652,7 +650,7 @@ importers:
sdks/nodejs-client:
dependencies:
axios:
specifier: 'catalog:'
specifier: 1.14.0
version: 1.14.0
devDependencies:
'@eslint/js':

View File

@@ -1,3 +1,12 @@
trustPolicy: no-downgrade
minimumReleaseAge: 1440
blockExoticSubdeps: true
strictDepBuilds: true
allowBuilds:
'@parcel/watcher': false
canvas: false
esbuild: false
sharp: false
packages:
- web
- e2e
@@ -13,6 +22,7 @@ overrides:
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
assert: npm:@nolyfill/assert@^1.0.26
axios: 1.14.0
brace-expansion@<2.0.2: 2.0.2
canvas: ^3.2.2
devalue@<5.3.2: 5.3.2
@@ -59,13 +69,6 @@ overrides:
which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44
yaml@>=2.0.0 <2.8.3: 2.8.3
yauzl@<3.2.1: 3.2.1
ignoredBuiltDependencies:
- canvas
- core-js-pure
onlyBuiltDependencies:
- "@parcel/watcher"
- esbuild
- sharp
catalog:
"@amplitude/analytics-browser": 2.38.0
"@amplitude/plugin-session-replay-browser": 1.27.5
@@ -149,7 +152,7 @@ catalog:
agentation: 3.0.2
ahooks: 3.9.7
autoprefixer: 10.4.27
axios: ^1.14.0
axios: 1.14.0
class-variance-authority: 0.7.1
clsx: 2.1.1
cmdk: 1.1.1