mirror of
https://github.com/langgenius/dify.git
synced 2026-03-31 18:00:55 -04:00
Merge branch 'main' into 3-31-vite-task-cache
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant(flask_req_ctx):
|
||||
with session_factory.create_session() as session:
|
||||
t = Tenant(name="plugin_it_tenant")
|
||||
session.add(t)
|
||||
session.commit()
|
||||
tenant_id = t.id
|
||||
|
||||
yield tenant_id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id))
|
||||
session.execute(
|
||||
delete(TenantPluginAutoUpgradeStrategy).where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
)
|
||||
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestPluginPermissionLifecycle:
|
||||
def test_get_returns_none_for_new_tenant(self, tenant):
|
||||
assert PluginPermissionService.get_permission(tenant) is None
|
||||
|
||||
def test_change_creates_row(self, tenant):
|
||||
result = PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.ADMINS,
|
||||
TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
perm = PluginPermissionService.get_permission(tenant)
|
||||
assert perm is not None
|
||||
assert perm.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert perm.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
def test_change_updates_existing_row(self, tenant):
|
||||
PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.ADMINS,
|
||||
TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
perm = PluginPermissionService.get_permission(tenant)
|
||||
assert perm is not None
|
||||
assert perm.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
|
||||
assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count()
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestPluginAutoUpgradeLifecycle:
|
||||
def test_get_returns_none_for_new_tenant(self, tenant):
|
||||
assert PluginAutoUpgradeService.get_strategy(tenant) is None
|
||||
|
||||
def test_change_creates_row(self, tenant):
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=3,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
assert result is True
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert strategy.upgrade_time_of_day == 3
|
||||
|
||||
def test_change_updates_existing_row(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=12,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=["plugin-a"],
|
||||
)
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert strategy.upgrade_time_of_day == 12
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert strategy.include_plugins == ["plugin-a"]
|
||||
|
||||
def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant):
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
assert "my-plugin" in strategy.exclude_plugins
|
||||
|
||||
def test_exclude_plugin_appends_in_exclude_mode(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
exclude_plugins=["existing"],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert "existing" in strategy.exclude_plugins
|
||||
assert "new-plugin" in strategy.exclude_plugins
|
||||
|
||||
def test_exclude_plugin_dedup_in_exclude_mode(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
exclude_plugins=["same-plugin"],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.exclude_plugins.count("same-plugin") == 1
|
||||
|
||||
def test_exclude_from_partial_mode_removes_from_include(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=["p1", "p2"],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "p1")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert "p1" not in strategy.include_plugins
|
||||
assert "p2" in strategy.include_plugins
|
||||
|
||||
def test_exclude_from_all_mode_switches_to_exclude(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
assert "excluded-plugin" in strategy.exclude_plugins
|
||||
@@ -0,0 +1,348 @@
|
||||
import datetime
|
||||
import math
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import (
|
||||
App,
|
||||
Conversation,
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
MessageFeedback,
|
||||
)
|
||||
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
|
||||
from services.retention.conversation.messages_clean_service import MessagesCleanService
|
||||
|
||||
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
|
||||
_OLD = _NOW - datetime.timedelta(days=60)
|
||||
_VERY_OLD = _NOW - datetime.timedelta(days=90)
|
||||
_RECENT = _NOW - datetime.timedelta(days=5)
|
||||
|
||||
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
|
||||
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)
|
||||
|
||||
_DEFAULT_BATCH_SIZE = 100
|
||||
_PAGINATION_MESSAGE_COUNT = 25
|
||||
_PAGINATION_BATCH_SIZE = 8
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_and_app(flask_req_ctx):
|
||||
"""Creates a Tenant, App and Conversation for the test and cleans up after."""
|
||||
with session_factory.create_session() as session:
|
||||
tenant = Tenant(name="retention_it_tenant")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Retention IT App",
|
||||
mode="chat",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
conv = Conversation(
|
||||
app_id=app.id,
|
||||
mode="chat",
|
||||
name="test_conv",
|
||||
status="normal",
|
||||
from_source="console",
|
||||
_inputs={},
|
||||
)
|
||||
session.add(conv)
|
||||
session.commit()
|
||||
|
||||
tenant_id = tenant.id
|
||||
app_id = app.id
|
||||
conv_id = conv.id
|
||||
|
||||
yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(Conversation).where(Conversation.id == conv_id))
|
||||
session.execute(delete(App).where(App.id == app_id))
|
||||
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message:
|
||||
return Message(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
query="test",
|
||||
message=[{"text": "hello"}],
|
||||
answer="world",
|
||||
message_tokens=1,
|
||||
message_unit_price=0,
|
||||
answer_tokens=1,
|
||||
answer_unit_price=0,
|
||||
from_source="console",
|
||||
currency="USD",
|
||||
_inputs={},
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
class TestMessagesCleanServiceIntegration:
|
||||
@pytest.fixture
|
||||
def seed_messages(self, tenant_and_app):
|
||||
"""Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.
|
||||
Yields a semantic mapping keyed by age label.
|
||||
"""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
# Ordered tuple of (label, timestamp) for deterministic seeding
|
||||
timestamps = [
|
||||
("very_old", _VERY_OLD),
|
||||
("old", _OLD),
|
||||
("recent", _RECENT),
|
||||
]
|
||||
msg_ids: dict[str, str] = {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for label, ts in timestamps:
|
||||
msg = _make_message(app_id, conv_id, ts)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
msg_ids[label] = msg.id
|
||||
session.commit()
|
||||
|
||||
yield {"msg_ids": msg_ids, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(
|
||||
delete(Message)
|
||||
.where(Message.id.in_(list(msg_ids.values())))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def paginated_seed_messages(self, tenant_and_app):
|
||||
"""Seeds multiple messages separated by 1-second increments starting at _OLD."""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
msg_ids: list[str] = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for i in range(_PAGINATION_MESSAGE_COUNT):
|
||||
ts = _OLD + datetime.timedelta(seconds=i)
|
||||
msg = _make_message(app_id, conv_id, ts)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
msg_ids.append(msg.id)
|
||||
session.commit()
|
||||
|
||||
yield {"msg_ids": msg_ids, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False))
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def cascade_test_data(self, tenant_and_app):
|
||||
"""Seeds one Message with an associated Feedback and Annotation."""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
msg = _make_message(app_id, conv_id, _OLD)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_id,
|
||||
conversation_id=conv_id,
|
||||
message_id=msg.id,
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
)
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app_id,
|
||||
conversation_id=conv_id,
|
||||
message_id=msg.id,
|
||||
question="q",
|
||||
content="a",
|
||||
account_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add_all([feedback, annotation])
|
||||
session.commit()
|
||||
|
||||
msg_id = msg.id
|
||||
fb_id = feedback.id
|
||||
ann_id = annotation.id
|
||||
|
||||
yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
|
||||
session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
|
||||
session.execute(delete(Message).where(Message.id == msg_id))
|
||||
session.commit()
|
||||
|
||||
def test_dry_run_does_not_delete(self, seed_messages):
|
||||
"""Dry-run must count eligible rows without deleting any of them."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
all_ids = list(msg_ids.values())
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_WINDOW_START,
|
||||
end_before=_WINDOW_END,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
dry_run=True,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["filtered_messages"] == len(all_ids)
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
assert remaining == len(all_ids)
|
||||
|
||||
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
|
||||
"""All 3 seeded messages fall within the window and must be deleted."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
all_ids = list(msg_ids.values())
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_WINDOW_START,
|
||||
end_before=_WINDOW_END,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
dry_run=False,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == len(all_ids)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
assert remaining == 0
|
||||
|
||||
def test_start_from_filters_correctly(self, seed_messages):
|
||||
"""Only the message at _OLD falls within the narrow ±1 h window."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
|
||||
start = _OLD - datetime.timedelta(hours=1)
|
||||
end = _OLD + datetime.timedelta(hours=1)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=start,
|
||||
end_before=end,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
all_ids = list(msg_ids.values())
|
||||
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
|
||||
|
||||
assert msg_ids["old"] not in remaining_ids
|
||||
assert msg_ids["very_old"] in remaining_ids
|
||||
assert msg_ids["recent"] in remaining_ids
|
||||
|
||||
def test_cursor_pagination_across_batches(self, paginated_seed_messages):
|
||||
"""Messages must be deleted across multiple batches."""
|
||||
data = paginated_seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
|
||||
# _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s.
|
||||
pagination_window_start = _OLD - datetime.timedelta(seconds=1)
|
||||
pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=pagination_window_start,
|
||||
end_before=pagination_window_end,
|
||||
batch_size=_PAGINATION_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
|
||||
expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
|
||||
assert stats["batches"] >= expected_batches
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
|
||||
assert remaining == 0
|
||||
|
||||
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
|
||||
"""A window entirely in the future must yield zero matches."""
|
||||
far_future = _NOW + datetime.timedelta(days=365)
|
||||
even_further = far_future + datetime.timedelta(days=1)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=far_future,
|
||||
end_before=even_further,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_messages"] == 0
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
def test_relation_cascade_deletes(self, cascade_test_data):
|
||||
"""Deleting a Message must cascade to its Feedback and Annotation rows."""
|
||||
data = cascade_test_data
|
||||
msg_id = data["msg_id"]
|
||||
fb_id = data["fb_id"]
|
||||
ann_id = data["ann_id"]
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_OLD - datetime.timedelta(hours=1),
|
||||
end_before=_OLD + datetime.timedelta(hours=1),
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(Message).where(Message.id == msg_id).count() == 0
|
||||
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
|
||||
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
|
||||
|
||||
def test_factory_from_time_range_validation(self):
|
||||
with pytest.raises(ValueError, match="start_from"):
|
||||
MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_NOW,
|
||||
end_before=_OLD,
|
||||
)
|
||||
|
||||
def test_factory_from_days_validation(self):
|
||||
with pytest.raises(ValueError, match="days"):
|
||||
MessagesCleanService.from_days(
|
||||
policy=BillingDisabledPolicy(),
|
||||
days=-1,
|
||||
)
|
||||
|
||||
def test_factory_batch_size_validation(self):
|
||||
with pytest.raises(ValueError, match="batch_size"):
|
||||
MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_OLD,
|
||||
end_before=_NOW,
|
||||
batch_size=0,
|
||||
)
|
||||
@@ -0,0 +1,177 @@
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
import zipfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import (
|
||||
ArchiveSummary,
|
||||
WorkflowRunArchiver,
|
||||
)
|
||||
from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION
|
||||
|
||||
|
||||
class TestWorkflowRunArchiverInit:
|
||||
def test_start_from_without_end_before_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
|
||||
WorkflowRunArchiver(start_from=datetime.datetime(2025, 1, 1))
|
||||
|
||||
def test_end_before_without_start_from_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
|
||||
WorkflowRunArchiver(end_before=datetime.datetime(2025, 1, 1))
|
||||
|
||||
def test_start_equals_end_raises(self):
|
||||
ts = datetime.datetime(2025, 1, 1)
|
||||
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
|
||||
WorkflowRunArchiver(start_from=ts, end_before=ts)
|
||||
|
||||
def test_start_after_end_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
|
||||
WorkflowRunArchiver(
|
||||
start_from=datetime.datetime(2025, 6, 1),
|
||||
end_before=datetime.datetime(2025, 1, 1),
|
||||
)
|
||||
|
||||
def test_workers_zero_raises(self):
|
||||
with pytest.raises(ValueError, match="workers must be at least 1"):
|
||||
WorkflowRunArchiver(workers=0)
|
||||
|
||||
def test_valid_init_defaults(self):
|
||||
archiver = WorkflowRunArchiver(days=30, batch_size=50)
|
||||
assert archiver.days == 30
|
||||
assert archiver.batch_size == 50
|
||||
assert archiver.dry_run is False
|
||||
assert archiver.delete_after_archive is False
|
||||
assert archiver.start_from is None
|
||||
|
||||
def test_valid_init_with_time_range(self):
|
||||
start = datetime.datetime(2025, 1, 1)
|
||||
end = datetime.datetime(2025, 6, 1)
|
||||
archiver = WorkflowRunArchiver(start_from=start, end_before=end, workers=2)
|
||||
assert archiver.start_from is not None
|
||||
assert archiver.end_before is not None
|
||||
assert archiver.workers == 2
|
||||
|
||||
|
||||
class TestBuildArchiveBundle:
|
||||
def test_bundle_contains_manifest_and_all_tables(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
manifest_data = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8")
|
||||
table_payloads = dict.fromkeys(archiver.ARCHIVED_TABLES, b"")
|
||||
|
||||
bundle_bytes = archiver._build_archive_bundle(manifest_data, table_payloads)
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(bundle_bytes), "r") as zf:
|
||||
names = set(zf.namelist())
|
||||
assert "manifest.json" in names
|
||||
for table in archiver.ARCHIVED_TABLES:
|
||||
assert f"{table}.jsonl" in names, f"Missing {table}.jsonl in bundle"
|
||||
|
||||
def test_bundle_missing_table_payload_raises(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
manifest_data = b"{}"
|
||||
incomplete_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing archive payload"):
|
||||
archiver._build_archive_bundle(manifest_data, incomplete_payloads)
|
||||
|
||||
|
||||
class TestGenerateManifest:
|
||||
def test_manifest_structure(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats
|
||||
|
||||
run = MagicMock()
|
||||
run.id = str(uuid.uuid4())
|
||||
run.tenant_id = str(uuid.uuid4())
|
||||
run.app_id = str(uuid.uuid4())
|
||||
run.workflow_id = str(uuid.uuid4())
|
||||
run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0)
|
||||
|
||||
stats = [
|
||||
TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512),
|
||||
TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024),
|
||||
]
|
||||
|
||||
manifest = archiver._generate_manifest(run, stats)
|
||||
|
||||
assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION
|
||||
assert manifest["workflow_run_id"] == run.id
|
||||
assert manifest["tenant_id"] == run.tenant_id
|
||||
assert manifest["app_id"] == run.app_id
|
||||
assert "tables" in manifest
|
||||
assert manifest["tables"]["workflow_runs"]["row_count"] == 1
|
||||
assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123"
|
||||
assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2
|
||||
|
||||
|
||||
class TestFilterPaidTenants:
|
||||
def test_all_tenants_paid_when_billing_disabled(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
tenant_ids = {"t1", "t2", "t3"}
|
||||
|
||||
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = False
|
||||
result = archiver._filter_paid_tenants(tenant_ids)
|
||||
|
||||
assert result == tenant_ids
|
||||
|
||||
def test_empty_tenants_returns_empty(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = True
|
||||
result = archiver._filter_paid_tenants(set())
|
||||
|
||||
assert result == set()
|
||||
|
||||
def test_only_paid_plans_returned(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
mock_bulk = {
|
||||
"t1": {"plan": "professional"},
|
||||
"t2": {"plan": "sandbox"},
|
||||
"t3": {"plan": "team"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
billing.get_plan_bulk_with_cache.return_value = mock_bulk
|
||||
result = archiver._filter_paid_tenants({"t1", "t2", "t3"})
|
||||
|
||||
assert "t1" in result
|
||||
assert "t3" in result
|
||||
assert "t2" not in result
|
||||
|
||||
def test_billing_api_failure_returns_empty(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down")
|
||||
result = archiver._filter_paid_tenants({"t1"})
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
class TestDryRunArchive:
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
|
||||
def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx):
|
||||
archiver = WorkflowRunArchiver(days=90, dry_run=True)
|
||||
|
||||
with patch.object(archiver, "_get_runs_batch", return_value=[]):
|
||||
summary = archiver.run()
|
||||
|
||||
mock_get_storage.assert_not_called()
|
||||
assert isinstance(summary, ArchiveSummary)
|
||||
assert summary.runs_failed == 0
|
||||
@@ -1,6 +1,12 @@
|
||||
"""Testcontainers integration tests for rag_pipeline controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
@@ -9,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@@ -18,6 +25,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -38,6 +49,10 @@ class TestPipelineTemplateListApi:
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -99,6 +114,10 @@ class TestPipelineTemplateDetailApi:
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
@@ -136,35 +155,29 @@ class TestCustomizedPipelineTemplateApi:
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
def test_post_success(self, app, db_session_with_containers: Session):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
tenant_id = str(uuid4())
|
||||
template = PipelineCustomizedTemplate(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Template",
|
||||
description="Test",
|
||||
chunk_structure="hierarchical",
|
||||
icon={"icon": "📘"},
|
||||
position=0,
|
||||
yaml_content="yaml-data",
|
||||
install_count=0,
|
||||
language="en-US",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(template)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
with app.test_request_context("/"):
|
||||
response, status = method(api, template.id)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
@@ -173,32 +186,16 @@ class TestCustomizedPipelineTemplateApi:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
method(api, str(uuid4()))
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_datasets controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -19,6 +23,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
@@ -33,13 +41,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -47,14 +48,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
@@ -93,13 +86,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -107,14 +93,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
@@ -143,6 +121,10 @@ class TestCreateRagPipelineDatasetApi:
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_import controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
@@ -18,6 +24,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
@@ -30,7 +40,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
@@ -39,13 +48,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -53,14 +55,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -76,7 +70,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
@@ -85,13 +78,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -99,14 +85,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -122,7 +100,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
@@ -131,13 +108,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -145,14 +115,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -165,6 +127,10 @@ class TestRagPipelineImportApi:
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi:
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi:
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi:
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -301,23 +230,8 @@ class TestRagPipelineExportApi:
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -1,7 +1,13 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
|
||||
|
||||
import services
|
||||
@@ -38,6 +44,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -200,6 +210,10 @@ class TestDraftWorkflowApi:
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -275,6 +289,10 @@ class TestDraftRunNodes:
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -337,6 +355,10 @@ class TestPipelineRunApis:
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -364,45 +386,43 @@ class TestDraftNodeRun:
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_publish_success(self, app, db_session_with_containers: Session):
|
||||
from models.dataset import Pipeline
|
||||
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
tenant_id = str(uuid4())
|
||||
pipeline = Pipeline(
|
||||
tenant_id=tenant_id,
|
||||
name="test-pipeline",
|
||||
description="test",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(pipeline)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
id=str(uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -415,6 +435,10 @@ class TestPublishedPipelineApis:
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -471,6 +495,10 @@ class TestMiscApis:
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi:
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi:
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi:
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi:
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
@@ -640,14 +664,6 @@ class TestRagPipelineByIdApi:
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
@@ -657,14 +673,6 @@ class TestRagPipelineByIdApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -700,24 +708,8 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
workflow_service = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="DELETE"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService",
|
||||
return_value=workflow_service,
|
||||
@@ -725,12 +717,7 @@ class TestRagPipelineByIdApi:
|
||||
):
|
||||
result = method(api, pipeline, "old-workflow")
|
||||
|
||||
workflow_service.delete_workflow.assert_called_once_with(
|
||||
session=session,
|
||||
workflow_id="old-workflow",
|
||||
tenant_id="t1",
|
||||
)
|
||||
session.commit.assert_called_once()
|
||||
workflow_service.delete_workflow.assert_called_once()
|
||||
assert result == (None, 204)
|
||||
|
||||
def test_delete_active_workflow_rejected(self, app):
|
||||
@@ -745,6 +732,10 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,7 +1,10 @@
|
||||
"""Testcontainers integration tests for controllers.console.explore.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.conversation as conversation_module
|
||||
@@ -48,24 +51,12 @@ def user():
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db_and_session():
|
||||
with (
|
||||
patch.object(
|
||||
conversation_module,
|
||||
"db",
|
||||
MagicMock(session=MagicMock(), engine=MagicMock()),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.conversation.Session",
|
||||
MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestConversationListApi:
|
||||
def test_get_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -90,7 +81,7 @@ class TestConversationListApi:
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
|
||||
def test_last_conversation_not_exists(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -106,7 +97,7 @@ class TestConversationListApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_wrong_app_mode(self, app, non_chat_app):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -116,7 +107,11 @@ class TestConversationListApi:
|
||||
|
||||
|
||||
class TestConversationApi:
|
||||
def test_delete_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_delete_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -134,7 +129,7 @@ class TestConversationApi:
|
||||
assert status == 204
|
||||
assert body["result"] == "success"
|
||||
|
||||
def test_delete_not_found(self, app: Flask, chat_app, user):
|
||||
def test_delete_not_found(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -150,7 +145,7 @@ class TestConversationApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_delete_wrong_app_mode(self, app, non_chat_app):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -160,7 +155,11 @@ class TestConversationApi:
|
||||
|
||||
|
||||
class TestConversationRenameApi:
|
||||
def test_rename_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_rename_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -179,7 +178,7 @@ class TestConversationRenameApi:
|
||||
|
||||
assert result["id"] == "cid"
|
||||
|
||||
def test_rename_not_found(self, app: Flask, chat_app, user):
|
||||
def test_rename_not_found(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -197,7 +196,11 @@ class TestConversationRenameApi:
|
||||
|
||||
|
||||
class TestConversationPinApi:
|
||||
def test_pin_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_pin_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@@ -215,7 +218,11 @@ class TestConversationPinApi:
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_unpin_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_unpin_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationUnPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@@ -5,6 +5,7 @@ Covers:
|
||||
- License status caching (get_cached_license_status)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -15,9 +16,178 @@ from services.enterprise.enterprise_service import (
|
||||
VALID_LICENSE_CACHE_TTL,
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
WebAppSettings,
|
||||
WorkspacePermission,
|
||||
try_join_default_workspace,
|
||||
)
|
||||
|
||||
MODULE = "services.enterprise.enterprise_service"
|
||||
|
||||
|
||||
class TestEnterpriseServiceInfo:
|
||||
def test_get_info_delegates(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"version": "1.0"}
|
||||
result = EnterpriseService.get_info()
|
||||
|
||||
req.send_request.assert_called_once_with("GET", "/info")
|
||||
assert result == {"version": "1.0"}
|
||||
|
||||
def test_get_workspace_info_delegates(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"name": "ws"}
|
||||
result = EnterpriseService.get_workspace_info("tenant-1")
|
||||
|
||||
req.send_request.assert_called_once_with("GET", "/workspace/tenant-1/info")
|
||||
assert result == {"name": "ws"}
|
||||
|
||||
|
||||
class TestSsoSettingsLastUpdateTime:
|
||||
def test_app_sso_parses_valid_timestamp(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = "2025-01-15T10:30:00+00:00"
|
||||
result = EnterpriseService.get_app_sso_settings_last_update_time()
|
||||
|
||||
assert isinstance(result, datetime)
|
||||
assert result.year == 2025
|
||||
|
||||
def test_app_sso_raises_on_empty(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = ""
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.get_app_sso_settings_last_update_time()
|
||||
|
||||
def test_app_sso_raises_on_invalid_format(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = "not-a-date"
|
||||
with pytest.raises(ValueError, match="Invalid date format"):
|
||||
EnterpriseService.get_app_sso_settings_last_update_time()
|
||||
|
||||
def test_workspace_sso_parses_valid_timestamp(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = "2025-06-01T00:00:00+00:00"
|
||||
result = EnterpriseService.get_workspace_sso_settings_last_update_time()
|
||||
|
||||
assert isinstance(result, datetime)
|
||||
|
||||
def test_workspace_sso_raises_on_empty(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = None
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.get_workspace_sso_settings_last_update_time()
|
||||
|
||||
|
||||
class TestWorkspacePermissionService:
|
||||
def test_raises_on_empty_workspace_id(self):
|
||||
with pytest.raises(ValueError, match="workspace_id must be provided"):
|
||||
EnterpriseService.WorkspacePermissionService.get_permission("")
|
||||
|
||||
def test_raises_on_missing_data(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = None
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
|
||||
|
||||
def test_raises_on_missing_permission_key(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"other": "data"}
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
|
||||
|
||||
def test_returns_parsed_permission(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {
|
||||
"permission": {
|
||||
"workspaceId": "ws-1",
|
||||
"allowMemberInvite": True,
|
||||
"allowOwnerTransfer": False,
|
||||
}
|
||||
}
|
||||
result = EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
|
||||
|
||||
assert isinstance(result, WorkspacePermission)
|
||||
assert result.workspace_id == "ws-1"
|
||||
assert result.allow_member_invite is True
|
||||
assert result.allow_owner_transfer is False
|
||||
|
||||
|
||||
class TestWebAppAuth:
|
||||
def test_is_user_allowed_returns_result_field(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is True
|
||||
|
||||
def test_is_user_allowed_defaults_false(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {}
|
||||
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is False
|
||||
|
||||
def test_batch_is_user_allowed_returns_empty_for_no_apps(self):
|
||||
assert EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", []) == {}
|
||||
|
||||
def test_batch_is_user_allowed_raises_on_empty_response(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = None
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", ["a1"])
|
||||
|
||||
def test_get_app_access_mode_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id("")
|
||||
|
||||
def test_get_app_access_mode_returns_settings(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessMode": "public"}
|
||||
result = EnterpriseService.WebAppAuth.get_app_access_mode_by_id("a1")
|
||||
|
||||
assert isinstance(result, WebAppSettings)
|
||||
assert result.access_mode == "public"
|
||||
|
||||
def test_batch_get_returns_empty_for_no_apps(self):
|
||||
assert EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id([]) == {}
|
||||
|
||||
def test_batch_get_maps_access_modes(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessModes": {"a1": "public", "a2": "private"}}
|
||||
result = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1", "a2"])
|
||||
|
||||
assert result["a1"].access_mode == "public"
|
||||
assert result["a2"].access_mode == "private"
|
||||
|
||||
def test_batch_get_raises_on_invalid_format(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessModes": "not-a-dict"}
|
||||
with pytest.raises(ValueError, match="Invalid data format"):
|
||||
EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1"])
|
||||
|
||||
def test_update_access_mode_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode("", "public")
|
||||
|
||||
def test_update_access_mode_raises_on_invalid_mode(self):
|
||||
with pytest.raises(ValueError, match="access_mode must be"):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode("a1", "invalid")
|
||||
|
||||
def test_update_access_mode_delegates_and_returns(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
result = EnterpriseService.WebAppAuth.update_app_access_mode("a1", "public")
|
||||
|
||||
assert result is True
|
||||
req.send_request.assert_called_once_with(
|
||||
"POST", "/webapp/access-mode", json={"appId": "a1", "accessMode": "public"}
|
||||
)
|
||||
|
||||
def test_cleanup_webapp_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp("")
|
||||
|
||||
def test_cleanup_webapp_delegates(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp("a1")
|
||||
|
||||
req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"})
|
||||
|
||||
|
||||
class TestJoinDefaultWorkspace:
|
||||
def test_join_default_workspace_success(self):
|
||||
|
||||
@@ -7,14 +7,20 @@ This module covers the pre-uninstall plugin hook behavior:
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from httpx import HTTPStatusError
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
CheckCredentialPolicyComplianceRequest,
|
||||
CredentialPolicyViolationError,
|
||||
PluginCredentialType,
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
)
|
||||
|
||||
MODULE = "services.enterprise.plugin_manager_service"
|
||||
|
||||
|
||||
class TestTryPreUninstallPlugin:
|
||||
def test_try_pre_uninstall_plugin_success(self):
|
||||
@@ -88,3 +94,46 @@ class TestTryPreUninstallPlugin:
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
class TestCheckCredentialPolicyCompliance:
|
||||
def _request(self, cred_type=PluginCredentialType.MODEL):
|
||||
return CheckCredentialPolicyComplianceRequest(
|
||||
dify_credential_id="cred-1", provider="openai", credential_type=cred_type
|
||||
)
|
||||
|
||||
def test_passes_when_result_true(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
req.send_request.assert_called_once()
|
||||
|
||||
def test_raises_violation_when_result_false(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = {"result": False}
|
||||
with pytest.raises(CredentialPolicyViolationError, match="Credentials not available"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_raises_violation_on_invalid_response_format(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = "not-a-dict"
|
||||
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_raises_violation_on_api_exception(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.side_effect = ConnectionError("network fail")
|
||||
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_model_dump_serializes_credential_type_as_number(self):
|
||||
body = self._request(PluginCredentialType.TOOL)
|
||||
data = body.model_dump()
|
||||
|
||||
assert data["credential_type"] == 1
|
||||
assert data["dify_credential_id"] == "cred-1"
|
||||
|
||||
def test_model_credential_type_values(self):
|
||||
assert PluginCredentialType.MODEL.to_number() == 0
|
||||
assert PluginCredentialType.TOOL.to_number() == 1
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
MODULE = "services.plugin.plugin_auto_upgrade_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
|
||||
class TestGetStrategy:
|
||||
def test_returns_strategy_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
strategy = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = strategy
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.get_strategy("t1")
|
||||
|
||||
assert result is strategy
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.get_strategy("t1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestChangeStrategy:
|
||||
def test_creates_new_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
"t1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
3,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
"t1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
5,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
["p1"],
|
||||
["p2"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert existing.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert existing.upgrade_time_of_day == 5
|
||||
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
assert existing.include_plugins == ["p2"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestExcludePlugin:
|
||||
def test_creates_default_strategy_when_none_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
p1,
|
||||
p2,
|
||||
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
||||
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
||||
):
|
||||
strat_cls.StrategySetting.FIX_ONLY = "fix_only"
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
cs.return_value = True
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "plugin-1")
|
||||
|
||||
assert result is True
|
||||
cs.assert_called_once()
|
||||
|
||||
def test_appends_to_exclude_list_in_exclude_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p-existing"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p-new")
|
||||
|
||||
assert result is True
|
||||
assert existing.exclude_plugins == ["p-existing", "p-new"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_removes_from_include_list_in_partial_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "partial"
|
||||
existing.include_plugins = ["p1", "p2"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert result is True
|
||||
assert existing.include_plugins == ["p2"]
|
||||
|
||||
def test_switches_to_exclude_mode_from_all(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "all"
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert result is True
|
||||
assert existing.upgrade_mode == "exclude"
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
|
||||
def test_no_duplicate_in_exclude_list(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p1"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
MODULE = "services.plugin.plugin_permission_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
|
||||
class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
|
||||
assert result is permission
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
|
||||
)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
|
||||
)
|
||||
|
||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
session.commit.assert_called_once()
|
||||
session.add.assert_not_called()
|
||||
0
api/tests/unit_tests/services/retention/__init__.py
Normal file
0
api/tests/unit_tests/services/retention/__init__.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from services.retention.conversation.messages_clean_policy import (
|
||||
BillingDisabledPolicy,
|
||||
BillingSandboxPolicy,
|
||||
SimpleMessage,
|
||||
create_message_clean_policy,
|
||||
)
|
||||
|
||||
MODULE = "services.retention.conversation.messages_clean_policy"
|
||||
|
||||
|
||||
def _msg(msg_id: str, app_id: str, days_ago: int = 0) -> SimpleMessage:
|
||||
return SimpleMessage(
|
||||
id=msg_id,
|
||||
app_id=app_id,
|
||||
created_at=datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days_ago),
|
||||
)
|
||||
|
||||
|
||||
class TestBillingDisabledPolicy:
|
||||
def test_returns_all_message_ids(self):
|
||||
policy = BillingDisabledPolicy()
|
||||
msgs = [_msg("m1", "app1"), _msg("m2", "app2"), _msg("m3", "app1")]
|
||||
|
||||
result = policy.filter_message_ids(msgs, {"app1": "t1", "app2": "t2"})
|
||||
|
||||
assert set(result) == {"m1", "m2", "m3"}
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
assert BillingDisabledPolicy().filter_message_ids([], {}) == []
|
||||
|
||||
|
||||
class TestBillingSandboxPolicy:
|
||||
def _policy(self, plans, *, graceful_days=21, whitelist=None, now=1_000_000_000):
|
||||
return BillingSandboxPolicy(
|
||||
plan_provider=lambda _ids: plans,
|
||||
graceful_period_days=graceful_days,
|
||||
tenant_whitelist=whitelist,
|
||||
current_timestamp=now,
|
||||
)
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
assert policy.filter_message_ids([], {"app1": "t1"}) == []
|
||||
|
||||
def test_empty_app_to_tenant_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
assert policy.filter_message_ids([_msg("m1", "app1")], {}) == []
|
||||
|
||||
def test_empty_plans_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
msgs = [_msg("m1", "app1")]
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_non_sandbox_tenant_skipped(self):
|
||||
plans = {"t1": {"plan": "professional", "expiration_date": 0}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_sandbox_no_previous_subscription_deletes(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
|
||||
|
||||
def test_sandbox_expired_beyond_grace_period_deletes(self):
|
||||
now = 1_000_000_000
|
||||
expired_long_ago = now - (22 * 24 * 60 * 60) # 22 days ago > 21 day grace
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_long_ago}}
|
||||
policy = self._policy(plans, now=now)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
|
||||
|
||||
def test_sandbox_within_grace_period_kept(self):
|
||||
now = 1_000_000_000
|
||||
expired_recently = now - (10 * 24 * 60 * 60) # 10 days ago < 21 day grace
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_recently}}
|
||||
policy = self._policy(plans, now=now)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_whitelisted_tenant_skipped(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans, whitelist=["t1"])
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_message_without_tenant_mapping_skipped(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "unmapped_app")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_mixed_tenants_only_sandbox_deleted(self):
|
||||
plans = {
|
||||
"t_sandbox": {"plan": "sandbox", "expiration_date": -1},
|
||||
"t_pro": {"plan": "professional", "expiration_date": 0},
|
||||
}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app_sandbox"), _msg("m2", "app_pro")]
|
||||
app_map = {"app_sandbox": "t_sandbox", "app_pro": "t_pro"}
|
||||
|
||||
result = policy.filter_message_ids(msgs, app_map)
|
||||
|
||||
assert result == ["m1"]
|
||||
|
||||
|
||||
class TestCreateMessageCleanPolicy:
|
||||
def test_billing_disabled_returns_disabled_policy(self):
|
||||
with patch(f"{MODULE}.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = False
|
||||
policy = create_message_clean_policy()
|
||||
|
||||
assert isinstance(policy, BillingDisabledPolicy)
|
||||
|
||||
def test_billing_enabled_returns_sandbox_policy(self):
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config") as cfg,
|
||||
patch(f"{MODULE}.BillingService") as bs,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
bs.get_expired_subscription_cleanup_whitelist.return_value = ["wl1"]
|
||||
bs.get_plan_bulk_with_cache = MagicMock()
|
||||
policy = create_message_clean_policy(graceful_period_days=30)
|
||||
|
||||
assert isinstance(policy, BillingSandboxPolicy)
|
||||
@@ -0,0 +1,598 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderType
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
MODULE = "services.tools.tools_transform_service"
|
||||
|
||||
|
||||
class TestToolTransformService:
|
||||
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
|
||||
|
||||
def test_convert_tool_with_parameter_override(self):
|
||||
"""Test that runtime parameters correctly override base parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
base_param2 = Mock(spec=ToolParameter)
|
||||
base_param2.name = "param2"
|
||||
base_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param2.type = "string"
|
||||
base_param2.label = "Base Param 2"
|
||||
|
||||
# Create mock runtime parameters that override base parameters
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1" # Different label to verify override
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1, base_param2]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_tool"
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Find the overridden parameter
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
|
||||
|
||||
# Find the non-overridden parameter
|
||||
original_param = next((p for p in result.parameters if p.name == "param2"), None)
|
||||
assert original_param is not None
|
||||
assert original_param.label == "Base Param 2" # Should be base version
|
||||
|
||||
def test_convert_tool_with_additional_runtime_parameters(self):
|
||||
"""Test that additional runtime parameters are added to the final list"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters - one that overrides and one that's new
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "runtime_only"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "Runtime Only Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Check that both parameters are present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "runtime_only" in param_names
|
||||
|
||||
# Verify the overridden parameter has runtime version
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1"
|
||||
|
||||
# Verify the new runtime parameter is included
|
||||
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
|
||||
assert new_param is not None
|
||||
assert new_param.label == "Runtime Only Param"
|
||||
|
||||
def test_convert_tool_with_non_form_runtime_parameters(self):
|
||||
"""Test that non-FORM runtime parameters are not added as new parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters with different forms
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "llm_param"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "LLM Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 1 # Only the FORM parameter should be present
|
||||
|
||||
# Check that only the FORM parameter is present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "llm_param" not in param_names
|
||||
|
||||
def test_convert_tool_with_empty_parameters(self):
|
||||
"""Test conversion with empty base and runtime parameters"""
|
||||
# Create mock tool with no parameters
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = []
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = []
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 0
|
||||
|
||||
def test_convert_tool_with_none_parameters(self):
|
||||
"""Test conversion when base parameters is None"""
|
||||
# Create mock tool with None parameters
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = None
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = []
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 0
|
||||
|
||||
def test_convert_tool_parameter_order_preserved(self):
|
||||
"""Test that parameter order is preserved correctly"""
|
||||
# Create mock base parameters in specific order
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
base_param2 = Mock(spec=ToolParameter)
|
||||
base_param2.name = "param2"
|
||||
base_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param2.type = "string"
|
||||
base_param2.label = "Base Param 2"
|
||||
|
||||
base_param3 = Mock(spec=ToolParameter)
|
||||
base_param3.name = "param3"
|
||||
base_param3.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param3.type = "string"
|
||||
base_param3.label = "Base Param 3"
|
||||
|
||||
# Create runtime parameter that overrides middle parameter
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "param2"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "Runtime Param 2"
|
||||
|
||||
# Create new runtime parameter
|
||||
runtime_param4 = Mock(spec=ToolParameter)
|
||||
runtime_param4.name = "param4"
|
||||
runtime_param4.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param4.type = "string"
|
||||
runtime_param4.label = "Runtime Param 4"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 4
|
||||
|
||||
# Check that order is maintained: base parameters first, then new runtime parameters
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert param_names == ["param1", "param2", "param3", "param4"]
|
||||
|
||||
# Verify that param2 was overridden with runtime version
|
||||
param2 = result.parameters[1]
|
||||
assert param2.name == "param2"
|
||||
assert param2.label == "Runtime Param 2"
|
||||
|
||||
|
||||
class TestWorkflowProviderToUserProvider:
|
||||
"""Test cases for ToolTransformService.workflow_provider_to_user_provider method"""
|
||||
|
||||
def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
workflow_app_id = "app_123"
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1", "label2"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_workflow_tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["label1", "label2"]
|
||||
assert result.is_team_authorization is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method without workflow_app_id
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1"],
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == ["label1"]
|
||||
|
||||
def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
|
||||
"""Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method with explicit None values
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=None,
|
||||
workflow_app_id=None,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_preserves_other_fields(self):
|
||||
"""Test that workflow_provider_to_user_provider preserves all other entity fields."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller with various fields
|
||||
workflow_app_id = "app_456"
|
||||
provider_id = "provider_456"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "another_author"
|
||||
mock_controller.entity.identity.name = "another_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(
|
||||
en_US="Another description", zh_Hans="Another description"
|
||||
)
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
|
||||
mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.label = I18nObject(
|
||||
en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["automation", "workflow"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify all fields are preserved correctly
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "another_author"
|
||||
assert result.name == "another_workflow_tool"
|
||||
assert result.description.en_US == "Another description"
|
||||
assert result.description.zh_Hans == "Another description"
|
||||
assert result.icon == {"type": "emoji", "content": "⚙️"}
|
||||
assert result.icon_dark == {"type": "emoji", "content": "🔧"}
|
||||
assert result.label.en_US == "Another Workflow Tool"
|
||||
assert result.label.zh_Hans == "Another Workflow Tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["automation", "workflow"]
|
||||
assert result.masked_credentials == {}
|
||||
assert result.is_team_authorization is True
|
||||
assert result.allow_delete is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
||||
|
||||
class TestGetToolProviderIconUrl:
|
||||
def test_builtin_provider_returns_console_url(self):
|
||||
with patch(f"{MODULE}.dify_config") as cfg:
|
||||
cfg.CONSOLE_API_URL = "https://app.dify.ai"
|
||||
url = ToolTransformService.get_tool_provider_icon_url("builtin", "google", "icon.png")
|
||||
|
||||
assert "/builtin/google/icon" in url
|
||||
assert url.startswith("https://app.dify.ai/console/api/workspaces/current/tool-provider")
|
||||
|
||||
def test_builtin_provider_with_no_console_url(self):
|
||||
with patch(f"{MODULE}.dify_config") as cfg:
|
||||
cfg.CONSOLE_API_URL = None
|
||||
url = ToolTransformService.get_tool_provider_icon_url("builtin", "slack", "icon.png")
|
||||
|
||||
assert "/builtin/slack/icon" in url
|
||||
|
||||
def test_api_provider_parses_json_icon(self):
|
||||
icon_json = '{"background": "#fff", "content": "A"}'
|
||||
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon_json)
|
||||
assert result == {"background": "#fff", "content": "A"}
|
||||
|
||||
def test_api_provider_returns_dict_icon_directly(self):
|
||||
icon = {"background": "#000", "content": "B"}
|
||||
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon)
|
||||
assert result == icon
|
||||
|
||||
def test_api_provider_returns_fallback_on_invalid_json(self):
|
||||
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", "not-json")
|
||||
assert result == {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
def test_workflow_provider_behaves_like_api(self):
|
||||
icon = {"background": "#123", "content": "W"}
|
||||
assert ToolTransformService.get_tool_provider_icon_url("workflow", "wf", icon) == icon
|
||||
|
||||
def test_mcp_returns_icon_as_is(self):
|
||||
assert ToolTransformService.get_tool_provider_icon_url("mcp", "srv", "icon-value") == "icon-value"
|
||||
|
||||
def test_unknown_type_returns_empty(self):
|
||||
assert ToolTransformService.get_tool_provider_icon_url("unknown", "x", "i") == ""
|
||||
|
||||
|
||||
class TestRepackProvider:
|
||||
def test_repacks_dict_provider_icon(self):
|
||||
provider = {"type": "builtin", "name": "google", "icon": "old"}
|
||||
with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/new-url") as mock_fn:
|
||||
ToolTransformService.repack_provider("t1", provider)
|
||||
|
||||
assert provider["icon"] == "/new-url"
|
||||
mock_fn.assert_called_once_with(provider_type="builtin", provider_name="google", icon="old")
|
||||
|
||||
def test_repacks_tool_provider_api_entity_without_plugin(self):
|
||||
entity = MagicMock(spec=ToolProviderApiEntity)
|
||||
entity.plugin_id = None
|
||||
entity.type = ToolProviderType.BUILT_IN
|
||||
entity.name = "slack"
|
||||
entity.icon = "icon.svg"
|
||||
entity.icon_dark = "dark.svg"
|
||||
|
||||
with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/url"):
|
||||
ToolTransformService.repack_provider("t1", entity)
|
||||
|
||||
assert entity.icon == "/url"
|
||||
assert entity.icon_dark == "/url"
|
||||
|
||||
|
||||
class TestConvertMcpSchemaToParameter:
|
||||
def test_simple_object_schema(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"count": {"type": "integer", "description": "Result count"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
params = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(params) == 2
|
||||
query_param = next(p for p in params if p.name == "query")
|
||||
count_param = next(p for p in params if p.name == "count")
|
||||
assert query_param.required is True
|
||||
assert count_param.required is False
|
||||
assert count_param.type.value == "number"
|
||||
|
||||
def test_float_maps_to_number(self):
|
||||
schema = {"type": "object", "properties": {"rate": {"type": "float"}}, "required": []}
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "number"
|
||||
|
||||
def test_array_type_attaches_input_schema(self):
|
||||
prop = {"type": "array", "description": "Items", "items": {"type": "string"}}
|
||||
schema = {"type": "object", "properties": {"items": prop}, "required": []}
|
||||
param = ToolTransformService.convert_mcp_schema_to_parameter(schema)[0]
|
||||
assert param.input_schema is not None
|
||||
|
||||
def test_non_object_schema_returns_empty(self):
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "string"}) == []
|
||||
|
||||
def test_missing_properties_returns_empty(self):
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "object"}) == []
|
||||
|
||||
def test_list_type_uses_first_element(self):
|
||||
schema = {"type": "object", "properties": {"f": {"type": ["string", "null"]}}, "required": []}
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "string"
|
||||
|
||||
def test_missing_description_defaults_empty(self):
|
||||
schema = {"type": "object", "properties": {"f": {"type": "string"}}, "required": []}
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].llm_description == ""
|
||||
|
||||
|
||||
class TestApiProviderToController:
|
||||
def test_api_key_header_auth(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "api_key_header"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER)
|
||||
|
||||
def test_api_key_query_auth(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "api_key_query"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_QUERY)
|
||||
|
||||
def test_legacy_api_key_maps_to_header(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "api_key"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER)
|
||||
|
||||
def test_unknown_auth_defaults_to_none(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "something_else"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.NONE)
|
||||
6
pnpm-lock.yaml
generated
6
pnpm-lock.yaml
generated
@@ -249,9 +249,6 @@ catalogs:
|
||||
autoprefixer:
|
||||
specifier: 10.4.27
|
||||
version: 10.4.27
|
||||
axios:
|
||||
specifier: ^1.14.0
|
||||
version: 1.14.0
|
||||
class-variance-authority:
|
||||
specifier: 0.7.1
|
||||
version: 0.7.1
|
||||
@@ -573,6 +570,7 @@ overrides:
|
||||
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
|
||||
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
|
||||
assert: npm:@nolyfill/assert@^1.0.26
|
||||
axios: 1.14.0
|
||||
brace-expansion@<2.0.2: 2.0.2
|
||||
canvas: ^3.2.2
|
||||
devalue@<5.3.2: 5.3.2
|
||||
@@ -652,7 +650,7 @@ importers:
|
||||
sdks/nodejs-client:
|
||||
dependencies:
|
||||
axios:
|
||||
specifier: 'catalog:'
|
||||
specifier: 1.14.0
|
||||
version: 1.14.0
|
||||
devDependencies:
|
||||
'@eslint/js':
|
||||
|
||||
@@ -1,3 +1,12 @@
|
||||
trustPolicy: no-downgrade
|
||||
minimumReleaseAge: 1440
|
||||
blockExoticSubdeps: true
|
||||
strictDepBuilds: true
|
||||
allowBuilds:
|
||||
'@parcel/watcher': false
|
||||
canvas: false
|
||||
esbuild: false
|
||||
sharp: false
|
||||
packages:
|
||||
- web
|
||||
- e2e
|
||||
@@ -13,6 +22,7 @@ overrides:
|
||||
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
|
||||
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
|
||||
assert: npm:@nolyfill/assert@^1.0.26
|
||||
axios: 1.14.0
|
||||
brace-expansion@<2.0.2: 2.0.2
|
||||
canvas: ^3.2.2
|
||||
devalue@<5.3.2: 5.3.2
|
||||
@@ -59,13 +69,6 @@ overrides:
|
||||
which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44
|
||||
yaml@>=2.0.0 <2.8.3: 2.8.3
|
||||
yauzl@<3.2.1: 3.2.1
|
||||
ignoredBuiltDependencies:
|
||||
- canvas
|
||||
- core-js-pure
|
||||
onlyBuiltDependencies:
|
||||
- "@parcel/watcher"
|
||||
- esbuild
|
||||
- sharp
|
||||
catalog:
|
||||
"@amplitude/analytics-browser": 2.38.0
|
||||
"@amplitude/plugin-session-replay-browser": 1.27.5
|
||||
@@ -149,7 +152,7 @@ catalog:
|
||||
agentation: 3.0.2
|
||||
ahooks: 3.9.7
|
||||
autoprefixer: 10.4.27
|
||||
axios: ^1.14.0
|
||||
axios: 1.14.0
|
||||
class-variance-authority: 0.7.1
|
||||
clsx: 2.1.1
|
||||
cmdk: 1.1.1
|
||||
|
||||
Reference in New Issue
Block a user