Merge remote-tracking branch 'origin/main' into optional-plugin-invoke

# Conflicts:
#	api/tests/test_containers_integration_tests/conftest.py
This commit is contained in:
-LAN-
2026-03-23 10:24:22 +08:00
50 changed files with 12870 additions and 198 deletions

View File

@@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -284,27 +284,29 @@ class TidbOnQdrantVector(BaseVector):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
if not ids:
return
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=ids),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []

View File

@@ -335,7 +335,11 @@ class BillingService:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
plan_dict = json.loads(json_str)
# NOTE (hj24): New billing versions may return timestamp as str, and validate_python
# in non-strict mode will coerce it to the expected int type.
# To preserve compatibility, always keep non-strict mode here and avoid strict mode.
subscription_plan = subscription_adapter.validate_python(plan_dict)
# NOTE END
tenant_plans[tenant_id] = subscription_plan
except Exception:
logger.exception(

View File

@@ -704,7 +704,7 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
raise error
def _is_valid_url(self, url: str) -> bool:
"""Validate URL format."""

View File

@@ -168,6 +168,7 @@ class DifyTestContainers:
# Start Dify Sandbox container for code execution environment
# Dify Sandbox provides a secure environment for executing user code
# Use pinned version 0.2.12 to match production docker-compose configuration
logger.info("Initializing Dify Sandbox container...")
sandbox_image = os.getenv(SANDBOX_TEST_IMAGE_ENV, DEFAULT_SANDBOX_TEST_IMAGE)
self.dify_sandbox = DockerContainer(image=sandbox_image).with_network(self.network)

View File

@@ -0,0 +1,320 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.local import LocalProxy
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.message import (
ChatMessageListApi,
ChatMessagesQuery,
FeedbackExportQuery,
MessageAnnotationCountApi,
MessageApi,
MessageFeedbackApi,
MessageFeedbackExportApi,
MessageFeedbackPayload,
MessageSuggestedQuestionApi,
)
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from models import App, AppMode
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
return flask_app
@pytest.fixture
def mock_account():
from models.account import Account, AccountStatus
account = MagicMock(spec=Account)
account.id = "user_123"
account.timezone = "UTC"
account.status = AccountStatus.ACTIVE
account.is_admin_or_owner = True
account.current_tenant.current_role = "owner"
account.has_edit_permission = True
return account
@pytest.fixture
def mock_app_model():
app_model = MagicMock(spec=App)
app_model.id = "app_123"
app_model.mode = AppMode.CHAT
app_model.tenant_id = "tenant_123"
return app_model
@pytest.fixture(autouse=True)
def mock_csrf():
with patch("libs.login.check_csrf_token") as mock:
yield mock
import contextlib
@contextlib.contextmanager
def setup_test_context(
test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
):
with (
patch("extensions.ext_database.db") as mock_db,
patch("controllers.console.app.wraps.db", mock_db),
patch("controllers.console.wraps.db", mock_db),
patch("controllers.console.app.message.db", mock_db),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
):
# Set up a generic query mock that usually returns mock_app_model when getting app
app_query_mock = MagicMock()
app_query_mock.filter.return_value.first.return_value = mock_app_model
app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
app_query_mock.where.return_value.first.return_value = mock_app_model
app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model
data_query_mock = MagicMock()
def query_side_effect(*args, **kwargs):
if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
return app_query_mock
return data_query_mock
mock_db.session.query.side_effect = query_side_effect
mock_db.data_query = data_query_mock
# Let the caller override the stat db query logic
proxy_mock = LocalProxy(lambda: mock_account)
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
full_path = f"{route_path}?{query_string}" if qs else route_path
with (
patch("libs.login.current_user", proxy_mock),
patch("flask_login.current_user", proxy_mock),
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
):
with test_app.test_request_context(full_path, method=method, json=payload):
request.view_args = {"app_id": "app_123"}
if "suggested-questions" in route_path:
# simplistic extraction for message_id
parts = route_path.split("chat-messages/")
if len(parts) > 1:
request.view_args["message_id"] = parts[1].split("/")[0]
elif "messages/" in route_path and "chat-messages" not in route_path:
parts = route_path.split("messages/")
if len(parts) > 1:
request.view_args["message_id"] = parts[1].split("/")[0]
api_instance = endpoint_class()
# Check if it has a dispatch_request or method
if hasattr(api_instance, method.lower()):
yield api_instance, mock_db, request.view_args
class TestMessageValidators:
def test_chat_messages_query_validators(self):
# Test empty_to_none
assert ChatMessagesQuery.empty_to_none("") is None
assert ChatMessagesQuery.empty_to_none("val") == "val"
# Test validate_uuid
assert ChatMessagesQuery.validate_uuid(None) is None
assert (
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_message_feedback_validators(self):
assert (
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_feedback_export_validators(self):
assert FeedbackExportQuery.parse_bool(None) is None
assert FeedbackExportQuery.parse_bool(True) is True
assert FeedbackExportQuery.parse_bool("1") is True
assert FeedbackExportQuery.parse_bool("0") is False
assert FeedbackExportQuery.parse_bool("off") is False
with pytest.raises(ValueError):
FeedbackExportQuery.parse_bool("invalid")
class TestMessageEndpoints:
def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
with setup_test_context(
app,
ChatMessageListApi,
"/apps/app_123/chat-messages",
"GET",
mock_account,
mock_app_model,
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.first.return_value = None
with pytest.raises(NotFound):
api.get(**v_args)
def test_chat_message_list_success(self, app, mock_account, mock_app_model):
with setup_test_context(
app,
ChatMessageListApi,
"/apps/app_123/chat-messages",
"GET",
mock_account,
mock_app_model,
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
) as (api, mock_db, v_args):
mock_conv = MagicMock()
mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
mock_msg = MagicMock()
mock_msg.id = "msg_123"
mock_msg.feedbacks = []
mock_msg.annotation = None
mock_msg.annotation_hit_history = None
mock_msg.agent_thoughts = []
mock_msg.message_files = []
mock_msg.extra_contents = []
mock_msg.message = {}
mock_msg.message_metadata_dict = {}
# mock returns
q_mock = mock_db.data_query
q_mock.where.return_value.first.side_effect = [mock_conv]
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
mock_db.session.scalar.return_value = False
resp = api.get(**v_args)
assert resp["limit"] == 1
assert resp["has_more"] is False
assert len(resp["data"]) == 1
def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
with setup_test_context(
app,
MessageFeedbackApi,
"/apps/app_123/feedbacks",
"POST",
mock_account,
mock_app_model,
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.first.return_value = None
with pytest.raises(NotFound):
api.post(**v_args)
def test_message_feedback_success(self, app, mock_account, mock_app_model):
payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
with setup_test_context(
app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
) as (api, mock_db, v_args):
mock_msg = MagicMock()
mock_msg.admin_feedback = None
mock_db.data_query.where.return_value.first.return_value = mock_msg
resp = api.post(**v_args)
assert resp == {"result": "success"}
def test_message_annotation_count(self, app, mock_account, mock_app_model):
with setup_test_context(
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.count.return_value = 5
resp = api.get(**v_args)
assert resp == {"count": 5}
@patch("controllers.console.app.message.MessageService")
def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]
with setup_test_context(
app,
MessageSuggestedQuestionApi,
"/apps/app_123/chat-messages/msg_123/suggested-questions",
"GET",
mock_account,
mock_app_model,
) as (api, mock_db, v_args):
resp = api.get(**v_args)
assert resp == {"data": ["q1", "q2"]}
@pytest.mark.parametrize(
("exc", "expected_exc"),
[
(MessageNotExistsError, NotFound),
(ConversationNotExistsError, NotFound),
(ProviderTokenNotInitError, ProviderNotInitializeError),
(QuotaExceededError, ProviderQuotaExceededError),
(ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
(SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
(Exception, InternalServerError),
],
)
@patch("controllers.console.app.message.MessageService")
def test_message_suggested_questions_errors(
self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
):
mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()
with setup_test_context(
app,
MessageSuggestedQuestionApi,
"/apps/app_123/chat-messages/msg_123/suggested-questions",
"GET",
mock_account,
mock_app_model,
) as (api, mock_db, v_args):
with pytest.raises(expected_exc):
api.get(**v_args)
@patch("services.feedback_service.FeedbackService.export_feedbacks")
def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
mock_export.return_value = {"exported": True}
with setup_test_context(
app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args):
resp = api.get(**v_args)
assert resp == {"exported": True}
def test_message_api_get_success(self, app, mock_account, mock_app_model):
with setup_test_context(
app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args):
mock_msg = MagicMock()
mock_msg.id = "msg_123"
mock_msg.feedbacks = []
mock_msg.annotation = None
mock_msg.annotation_hit_history = None
mock_msg.agent_thoughts = []
mock_msg.message_files = []
mock_msg.extra_contents = []
mock_msg.message = {}
mock_msg.message_metadata_dict = {}
mock_db.data_query.where.return_value.first.return_value = mock_msg
resp = api.get(**v_args)
assert resp["id"] == "msg_123"

View File

@@ -0,0 +1,275 @@
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.statistic import (
AverageResponseTimeStatistic,
AverageSessionInteractionStatistic,
DailyConversationStatistic,
DailyMessageStatistic,
DailyTerminalsStatistic,
DailyTokenCostStatistic,
TokensPerSecondStatistic,
UserSatisfactionRateStatistic,
)
from models import App, AppMode
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture
def mock_account():
from models.account import Account, AccountStatus
account = MagicMock(spec=Account)
account.id = "user_123"
account.timezone = "UTC"
account.status = AccountStatus.ACTIVE
account.is_admin_or_owner = True
account.current_tenant.current_role = "owner"
account.has_edit_permission = True
return account
@pytest.fixture
def mock_app_model():
app_model = MagicMock(spec=App)
app_model.id = "app_123"
app_model.mode = AppMode.CHAT
app_model.tenant_id = "tenant_123"
return app_model
@pytest.fixture(autouse=True)
def mock_csrf():
with patch("libs.login.check_csrf_token") as mock:
yield mock
def setup_test_context(
test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
):
with (
patch("controllers.console.app.statistic.db") as mock_db_stat,
patch("controllers.console.app.wraps.db") as mock_db_wraps,
patch("controllers.console.wraps.db", mock_db_wraps),
patch(
"controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
):
mock_conn = MagicMock()
mock_conn.execute.return_value = mock_rs
mock_begin = MagicMock()
mock_begin.__enter__.return_value = mock_conn
mock_db_stat.engine.begin.return_value = mock_begin
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = mock_app_model
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
mock_query.where.return_value.first.return_value = mock_app_model
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with test_app.test_request_context(route_path, method="GET"):
request.view_args = {"app_id": "app_123"}
api_instance = endpoint_class()
response = api_instance.get(app_id="app_123")
return response
class TestStatisticEndpoints:
def test_daily_message_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.message_count = 10
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyMessageStatistic,
"/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["message_count"] == 10
def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.conversation_count = 5
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyConversationStatistic,
"/apps/app_123/statistics/daily-conversations",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["conversation_count"] == 5
def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.terminal_count = 2
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyTerminalsStatistic,
"/apps/app_123/statistics/daily-end-users",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["terminal_count"] == 2
def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.token_count = 100
mock_row.total_price = Decimal("0.02")
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyTokenCostStatistic,
"/apps/app_123/statistics/token-costs",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["token_count"] == 100
assert response.json["data"][0]["total_price"] == "0.02"
def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.interactions = Decimal("3.523")
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
AverageSessionInteractionStatistic,
"/apps/app_123/statistics/average-session-interactions",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["interactions"] == 3.52
def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.message_count = 100
mock_row.feedback_count = 10
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
UserSatisfactionRateStatistic,
"/apps/app_123/statistics/user-satisfaction-rate",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["rate"] == 100.0
def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
mock_app_model.mode = AppMode.COMPLETION
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.latency = 1.234
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
AverageResponseTimeStatistic,
"/apps/app_123/statistics/average-response-time",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["latency"] == 1234.0
def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.tokens_per_second = 15.5
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
TokensPerSecondStatistic,
"/apps/app_123/statistics/tokens-per-second",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["tps"] == 15.5
@patch("controllers.console.app.statistic.parse_time_range")
def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
mock_parse.side_effect = ValueError("Invalid time")
from werkzeug.exceptions import BadRequest
with pytest.raises(BadRequest):
setup_test_context(
app,
DailyMessageStatistic,
"/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
mock_account,
mock_app_model,
[],
)
@patch("controllers.console.app.statistic.parse_time_range")
def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
import datetime
start = datetime.datetime.now()
end = datetime.datetime.now()
mock_parse.return_value = (start, end)
response = setup_test_context(
app,
DailyMessageStatistic,
"/apps/app_123/statistics/daily-messages?start=something&end=something",
mock_account,
mock_app_model,
[],
)
assert response.status_code == 200
mock_parse.assert_called_once()

View File

@@ -0,0 +1,313 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
ConversationVariableCollectionApi,
EnvironmentVariableCollectionApi,
NodeVariableCollectionApi,
SystemVariableCollectionApi,
VariableApi,
VariableResetApi,
WorkflowVariableCollectionApi,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from models import App, AppMode
from models.enums import DraftVariableType
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
return flask_app
@pytest.fixture
def mock_account():
from models.account import Account, AccountStatus
account = MagicMock(spec=Account)
account.id = "user_123"
account.timezone = "UTC"
account.status = AccountStatus.ACTIVE
account.is_admin_or_owner = True
account.current_tenant.current_role = "owner"
account.has_edit_permission = True
return account
@pytest.fixture
def mock_app_model():
app_model = MagicMock(spec=App)
app_model.id = "app_123"
app_model.mode = AppMode.WORKFLOW
app_model.tenant_id = "tenant_123"
return app_model
@pytest.fixture(autouse=True)
def mock_csrf():
with patch("libs.login.check_csrf_token") as mock:
yield mock
def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None):
with (
patch("controllers.console.app.wraps.db") as mock_db_wraps,
patch("controllers.console.wraps.db", mock_db_wraps),
patch("controllers.console.app.workflow_draft_variable.db"),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
):
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = mock_app_model
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
mock_query.where.return_value.first.return_value = mock_app_model
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with test_app.test_request_context(route_path, method=method, json=payload):
request.view_args = {"app_id": "app_123"}
# extract node_id or variable_id from path manually since view_args overrides
if "nodes/" in route_path:
request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0]
if "variables/" in route_path:
# simplistic extraction
parts = route_path.split("variables/")
if len(parts) > 1 and parts[1] and parts[1] != "reset":
request.view_args["variable_id"] = parts[1].split("/")[0]
api_instance = endpoint_class()
# we just call dispatch_request to avoid manual argument passing
if hasattr(api_instance, method.lower()):
func = getattr(api_instance, method.lower())
return func(**request.view_args)
class TestWorkflowDraftVariableEndpoints:
@staticmethod
def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock:
class DummyValueType:
def exposed_type(self):
return DraftVariableType.NODE
mock_var = MagicMock()
mock_var.app_id = "app_123"
mock_var.id = "var_123"
mock_var.name = "test_var"
mock_var.description = ""
mock_var.get_variable_type.return_value = variable_type
mock_var.get_selector.return_value = []
mock_var.value_type = DummyValueType()
mock_var.edited = False
mock_var.visible = True
mock_var.file_id = None
mock_var.variable_file = None
mock_var.is_truncated.return_value = False
mock_var.get_value.return_value.model_copy.return_value.value = "test_value"
return mock_var
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_workflow_variable_collection_get_success(
self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model
):
mock_wf_srv.return_value.is_workflow_exist.return_value = True
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList(
variables=[], total=0
)
resp = setup_test_context(
app,
WorkflowVariableCollectionApi,
"/apps/app_123/workflows/draft/variables?page=1&limit=20",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": [], "total": 0}
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf_srv.return_value.is_workflow_exist.return_value = False
with pytest.raises(DraftWorkflowNotExist):
setup_test_context(
app,
WorkflowVariableCollectionApi,
"/apps/app_123/workflows/draft/variables",
"GET",
mock_account,
mock_app_model,
)
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
resp = setup_test_context(
app,
WorkflowVariableCollectionApi,
"/apps/app_123/workflows/draft/variables",
"DELETE",
mock_account,
mock_app_model,
)
assert resp.status_code == 204
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[])
resp = setup_test_context(
app,
NodeVariableCollectionApi,
"/apps/app_123/workflows/draft/nodes/node_123/variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}
def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model):
with pytest.raises(InvalidArgumentError):
setup_test_context(
app,
NodeVariableCollectionApi,
"/apps/app_123/workflows/draft/nodes/sys/variables",
"GET",
mock_account,
mock_app_model,
)
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
resp = setup_test_context(
app,
NodeVariableCollectionApi,
"/apps/app_123/workflows/draft/nodes/node_123/variables",
"DELETE",
mock_account,
mock_app_model,
)
assert resp.status_code == 204
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
resp = setup_test_context(
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
)
assert resp["id"] == "var_123"
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = None
with pytest.raises(NotFoundError):
setup_test_context(
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
)
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
resp = setup_test_context(
app,
VariableApi,
"/apps/app_123/workflows/draft/variables/var_123",
"PATCH",
mock_account,
mock_app_model,
payload={"name": "new_name"},
)
assert resp["id"] == "var_123"
mock_draft_srv.return_value.update_variable.assert_called_once()
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
resp = setup_test_context(
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model
)
assert resp.status_code == 204
mock_draft_srv.return_value.delete_variable.assert_called_once()
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
mock_draft_srv.return_value.reset_variable.return_value = None # means no content
resp = setup_test_context(
app,
VariableResetApi,
"/apps/app_123/workflows/draft/variables/var_123/reset",
"PUT",
mock_account,
mock_app_model,
)
assert resp.status_code == 204
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[])
resp = setup_test_context(
app,
ConversationVariableCollectionApi,
"/apps/app_123/workflows/draft/conversation-variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model):
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[])
resp = setup_test_context(
app,
SystemVariableCollectionApi,
"/apps/app_123/workflows/draft/system-variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf = MagicMock()
mock_wf.environment_variables = []
mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf
resp = setup_test_context(
app,
EnvironmentVariableCollectionApi,
"/apps/app_123/workflows/draft/environment-variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}

View File

@@ -0,0 +1,209 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.data_source_bearer_auth import (
ApiKeyAuthDataSource,
ApiKeyAuthDataSourceBinding,
ApiKeyAuthDataSourceBindingDelete,
)
from controllers.console.auth.error import ApiKeyAuthFailedError
class TestApiKeyAuthDataSource:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False
return app
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_binding = MagicMock()
mock_binding.id = "bind_123"
mock_binding.category = "api_key"
mock_binding.provider = "custom_provider"
mock_binding.disabled = False
mock_binding.created_at.timestamp.return_value = 1620000000
mock_binding.updated_at.timestamp.return_value = 1620000001
mock_get_list.return_value = [mock_binding]
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSource()
response = api_instance.get()
assert "sources" in response
assert len(response["sources"]) == 1
assert response["sources"][0]["provider"] == "custom_provider"
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_get_list.return_value = None
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSource()
response = api_instance.get()
assert "sources" in response
assert len(response["sources"]) == 0
class TestApiKeyAuthDataSourceBinding:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False
return app
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context(
"/console/api/api-key-auth/data-source/binding",
method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSourceBinding()
response = api_instance.post()
assert response[0]["result"] == "success"
assert response[1] == 200
mock_validate.assert_called_once()
mock_create.assert_called_once()
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_create.side_effect = ValueError("Invalid structure")
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context(
"/console/api/api-key-auth/data-source/binding",
method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSourceBinding()
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
api_instance.post()
class TestApiKeyAuthDataSourceBindingDelete:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False
return app
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSourceBindingDelete()
response = api_instance.delete("binding_123")
assert response[0]["result"] == "success"
assert response[1] == 204
mock_delete.assert_called_once_with("tenant_123", "binding_123")

View File

@@ -0,0 +1,192 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.local import LocalProxy
from controllers.console.auth.data_source_oauth import (
OAuthDataSource,
OAuthDataSourceBinding,
OAuthDataSourceCallback,
OAuthDataSourceSync,
)
class TestOAuthDataSource:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("flask_login.current_user")
@patch("libs.login.current_user")
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
def test_get_oauth_url_successful(
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
):
mock_oauth_provider = MagicMock()
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
mock_get_providers.return_value = {"notion": mock_oauth_provider}
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_libs_user.return_value = mock_account
mock_flask_user.return_value = mock_account
# also patch current_account_with_tenant
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSource()
response = api_instance.get("notion")
assert response[0]["data"] == "http://oauth.provider/auth"
assert response[1] == 200
mock_oauth_provider.get_authorization_url.assert_called_once()
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("flask_login.current_user")
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app):
mock_get_providers.return_value = {"notion": MagicMock()}
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSource()
response = api_instance.get("unknown_provider")
assert response[0]["error"] == "Invalid provider"
assert response[1] == 400
class TestOAuthDataSourceCallback:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_oauth_callback_successful(self, mock_get_providers, app):
provider_mock = MagicMock()
mock_get_providers.return_value = {"notion": provider_mock}
with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"):
api_instance = OAuthDataSourceCallback()
response = api_instance.get("notion")
assert response.status_code == 302
assert "code=mock_code" in response.location
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_oauth_callback_missing_code(self, mock_get_providers, app):
provider_mock = MagicMock()
mock_get_providers.return_value = {"notion": provider_mock}
with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"):
api_instance = OAuthDataSourceCallback()
response = api_instance.get("notion")
assert response.status_code == 302
assert "error=Access denied" in response.location
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_oauth_callback_invalid_provider(self, mock_get_providers, app):
mock_get_providers.return_value = {"notion": MagicMock()}
with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"):
api_instance = OAuthDataSourceCallback()
response = api_instance.get("invalid")
assert response[0]["error"] == "Invalid provider"
assert response[1] == 400
class TestOAuthDataSourceBinding:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_get_binding_successful(self, mock_get_providers, app):
mock_provider = MagicMock()
mock_provider.get_access_token.return_value = None
mock_get_providers.return_value = {"notion": mock_provider}
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"):
api_instance = OAuthDataSourceBinding()
response = api_instance.get("notion")
assert response[0]["result"] == "success"
assert response[1] == 200
mock_provider.get_access_token.assert_called_once_with("auth_code_123")
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_get_binding_missing_code(self, mock_get_providers, app):
mock_get_providers.return_value = {"notion": MagicMock()}
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"):
api_instance = OAuthDataSourceBinding()
response = api_instance.get("notion")
assert response[0]["error"] == "Invalid code"
assert response[1] == 400
class TestOAuthDataSourceSync:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app):
mock_provider = MagicMock()
mock_provider.sync_data_source.return_value = None
mock_get_providers.return_value = {"notion": mock_provider}
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSourceSync()
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
response = api_instance.get("notion", "binding_123")
assert response[0]["result"] == "success"
assert response[1] == 200
mock_provider.sync_data_source.assert_called_once_with("binding_123")

View File

@@ -0,0 +1,417 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.auth.oauth_server import (
OAuthServerAppApi,
OAuthServerUserAccountApi,
OAuthServerUserAuthorizeApi,
OAuthServerUserTokenApi,
)
class TestOAuthServerAppApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
from models.model import OAuthProviderApp
oauth_app = MagicMock(spec=OAuthProviderApp)
oauth_app.client_id = "test_client_id"
oauth_app.redirect_uris = ["http://localhost/callback"]
oauth_app.app_icon = "icon_url"
oauth_app.app_label = "Test App"
oauth_app.scope = "read,write"
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider",
method="POST",
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
):
api_instance = OAuthServerAppApi()
response = api_instance.post()
assert response["app_icon"] == "icon_url"
assert response["app_label"] == "Test App"
assert response["scope"] == "read,write"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider",
method="POST",
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
):
api_instance = OAuthServerAppApi()
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_client_id(self, mock_get_app, mock_db, app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = None
with app.test_request_context(
"/oauth/provider",
method="POST",
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
):
api_instance = OAuthServerAppApi()
with pytest.raises(NotFound, match="client_id is invalid"):
api_instance.post()
class TestOAuthServerUserAuthorizeApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
oauth_app = MagicMock()
oauth_app.client_id = "test_client_id"
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.current_account_with_tenant")
@patch("controllers.console.wraps.current_account_with_tenant")
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
@patch("libs.login.check_csrf_token")
def test_successful_authorize(
self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_account = MagicMock()
mock_account.id = "user_123"
from models.account import AccountStatus
mock_account.status = AccountStatus.ACTIVE
mock_current.return_value = (mock_account, MagicMock())
mock_wrap_current.return_value = (mock_account, MagicMock())
mock_sign.return_value = "auth_code_123"
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
with patch("libs.login.current_user", mock_account):
api_instance = OAuthServerUserAuthorizeApi()
response = api_instance.post()
assert response["code"] == "auth_code_123"
mock_sign.assert_called_once_with("test_client_id", "user_123")
class TestOAuthServerUserTokenApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
from models.model import OAuthProviderApp
oauth_app = MagicMock(spec=OAuthProviderApp)
oauth_app.client_id = "test_client_id"
oauth_app.client_secret = "test_secret"
oauth_app.redirect_uris = ["http://localhost/callback"]
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_sign.return_value = ("access_123", "refresh_123")
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
):
api_instance = OAuthServerUserTokenApi()
response = api_instance.post()
assert response["access_token"] == "access_123"
assert response["refresh_token"] == "refresh_123"
assert response["token_type"] == "Bearer"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="code is required"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "invalid_secret",
"redirect_uri": "http://localhost/callback",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="client_secret is invalid"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://invalid/callback",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_sign.return_value = ("new_access", "new_refresh")
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
):
api_instance = OAuthServerUserTokenApi()
response = api_instance.post()
assert response["access_token"] == "new_access"
assert response["refresh_token"] == "new_refresh"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "refresh_token",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="refresh_token is required"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "invalid_grant",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="invalid grant_type"):
api_instance.post()
class TestOAuthServerUserAccountApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
from models.model import OAuthProviderApp
oauth_app = MagicMock(spec=OAuthProviderApp)
oauth_app.client_id = "test_client_id"
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_account = MagicMock()
mock_account.name = "Test User"
mock_account.email = "test@example.com"
mock_account.avatar = "avatar_url"
mock_account.interface_language = "en-US"
mock_account.timezone = "UTC"
mock_validate.return_value = mock_account
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer valid_access_token"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response["name"] == "Test User"
assert response["email"] == "test@example.com"
assert response["avatar"] == "avatar_url"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "Authorization header is required"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "InvalidFormat"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "Invalid Authorization header format"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Basic something"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "token_type is invalid"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer "},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "Invalid Authorization header format"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_validate.return_value = None
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer invalid_token"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "access_token or client_id is invalid"

View File

@@ -0,0 +1,181 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
from models.api_based_extension import APIBasedExtension
class TestApiModeration:
@pytest.fixture
def api_config(self):
return {
"inputs_config": {
"enabled": True,
},
"outputs_config": {
"enabled": True,
},
"api_based_extension_id": "test-extension-id",
}
@pytest.fixture
def api_moderation(self, api_config):
return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)
def test_moderation_input_params(self):
params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
assert params.app_id == "app-1"
assert params.inputs == {"key": "val"}
assert params.query == "test query"
# Test defaults
params_default = ModerationInputParams()
assert params_default.app_id == ""
assert params_default.inputs == {}
assert params_default.query == ""
def test_moderation_output_params(self):
params = ModerationOutputParams(app_id="app-1", text="test text")
assert params.app_id == "app-1"
assert params.text == "test text"
with pytest.raises(ValidationError):
ModerationOutputParams()
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_validate_config_success(self, mock_get_extension, api_config):
mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
ApiModeration.validate_config("test-tenant-id", api_config)
mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")
def test_validate_config_missing_extension_id(self):
config = {
"inputs_config": {"enabled": True},
"outputs_config": {"enabled": True},
}
with pytest.raises(ValueError, match="api_based_extension_id is required"):
ApiModeration.validate_config("test-tenant-id", config)
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
mock_get_extension.return_value = None
with pytest.raises(ValueError, match="API-based Extension not found"):
ApiModeration.validate_config("test-tenant-id", api_config)
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")
assert isinstance(result, ModerationInputsResult)
assert result.flagged is True
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == "Blocked by API"
mock_get_config.assert_called_once_with(
APIBasedExtensionPoint.APP_MODERATION_INPUT,
{"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
)
def test_moderation_for_inputs_disabled(self):
config = {
"inputs_config": {"enabled": False},
"outputs_config": {"enabled": True},
"api_based_extension_id": "ext-id",
}
moderation = ApiModeration("app-id", "tenant-id", config)
result = moderation.moderation_for_inputs(inputs={}, query="")
assert result.flagged is False
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == ""
def test_moderation_for_inputs_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation.moderation_for_inputs({}, "")
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
result = api_moderation.moderation_for_outputs(text="hello world")
assert isinstance(result, ModerationOutputsResult)
assert result.flagged is False
mock_get_config.assert_called_once_with(
APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
)
def test_moderation_for_outputs_disabled(self):
config = {
"inputs_config": {"enabled": True},
"outputs_config": {"enabled": False},
"api_based_extension_id": "ext-id",
}
moderation = ApiModeration("app-id", "tenant-id", config)
result = moderation.moderation_for_outputs(text="test")
assert result.flagged is False
assert result.action == ModerationAction.DIRECT_OUTPUT
def test_moderation_for_outputs_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation.moderation_for_outputs("test")
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
@patch("core.moderation.api.api.decrypt_token")
@patch("core.moderation.api.api.APIBasedExtensionRequestor")
def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
mock_ext = MagicMock(spec=APIBasedExtension)
mock_ext.api_endpoint = "http://api.test"
mock_ext.api_key = "encrypted-key"
mock_get_ext.return_value = mock_ext
mock_decrypt.return_value = "decrypted-key"
mock_requestor = MagicMock()
mock_requestor.request.return_value = {"flagged": True}
mock_requestor_cls.return_value = mock_requestor
params = {"some": "params"}
result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
assert result == {"flagged": True}
mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
def test_get_config_by_requestor_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
mock_get_ext.return_value = None
with pytest.raises(ValueError, match="API-based Extension not found"):
api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
@patch("core.moderation.api.api.db.session.scalar")
def test_get_api_based_extension(self, mock_scalar):
mock_ext = MagicMock(spec=APIBasedExtension)
mock_scalar.return_value = mock_ext
result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")
assert result == mock_ext
mock_scalar.assert_called_once()
# Verify the call has the correct filters
args, kwargs = mock_scalar.call_args
stmt = args[0]
# We can't easily inspect the statement without complex sqlalchemy tricks,
# but calling it is usually enough for unit tests if we mock the result.

View File

@@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
from core.moderation.input_moderation import InputModeration
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager
class TestInputModeration:
@pytest.fixture
def app_config(self):
config = MagicMock(spec=AppConfig)
config.sensitive_word_avoidance = None
return config
@pytest.fixture
def input_moderation(self):
return InputModeration()
def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is False
assert final_inputs == inputs
assert final_query == query
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {"keywords": ["bad"]}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is False
assert final_inputs == inputs
assert final_query == query
mock_factory_cls.assert_called_once_with(
name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
)
mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)
@patch("core.moderation.input_moderation.ModerationFactory")
@patch("core.moderation.input_moderation.TraceTask")
def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
trace_manager = MagicMock(spec=TraceQueueManager)
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_inputs.return_value = mock_result
input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
trace_manager=trace_manager,
)
trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
mock_trace_task.assert_called_once()
call_kwargs = mock_trace_task.call_args.kwargs
call_args = mock_trace_task.call_args.args
assert call_args[0] == TraceTaskName.MODERATION_TRACE
assert call_kwargs["message_id"] == message_id
assert call_kwargs["moderation_result"] == mock_result
assert call_kwargs["inputs"] == inputs
assert "timer" in call_kwargs
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
)
mock_factory.moderation_for_inputs.return_value = mock_result
with pytest.raises(ModerationError) as excinfo:
input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
)
assert str(excinfo.value) == "Blocked content"
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(
flagged=True,
action=ModerationAction.OVERRIDDEN,
inputs={"input_key": "overridden_value"},
query="overridden query",
)
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is True
assert final_inputs == {"input_key": "overridden_value"}
assert final_query == "overridden query"
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = MagicMock()
mock_result.flagged = True
mock_result.action = "NONE" # Some other action
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
)
assert flagged is True
assert final_inputs == inputs
assert final_query == query

View File

@@ -0,0 +1,234 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.output_moderation import ModerationRule, OutputModeration
class TestOutputModeration:
@pytest.fixture
def mock_queue_manager(self):
return MagicMock(spec=AppQueueManager)
@pytest.fixture
def moderation_rule(self):
return ModerationRule(type="keywords", config={"keywords": "badword"})
@pytest.fixture
def output_moderation(self, mock_queue_manager, moderation_rule):
return OutputModeration(
tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
)
def test_should_direct_output(self, output_moderation):
assert output_moderation.should_direct_output() is False
output_moderation.final_output = "blocked"
assert output_moderation.should_direct_output() is True
def test_get_final_output(self, output_moderation):
assert output_moderation.get_final_output() == ""
output_moderation.final_output = "blocked"
assert output_moderation.get_final_output() == "blocked"
def test_append_new_token(self, output_moderation):
with patch.object(OutputModeration, "start_thread") as mock_start:
output_moderation.append_new_token("hello")
assert output_moderation.buffer == "hello"
mock_start.assert_called_once()
output_moderation.thread = MagicMock()
output_moderation.append_new_token(" world")
assert output_moderation.buffer == "hello world"
assert mock_start.call_count == 1
def test_moderation_completion_no_flag(self, output_moderation):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
output, flagged = output_moderation.moderation_completion("safe content")
assert output == "safe content"
assert flagged is False
assert output_moderation.is_final_chunk is True
def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
)
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
assert output == "preset"
assert flagged is True
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert isinstance(args[0], QueueMessageReplaceEvent)
assert args[0].text == "preset"
assert args[1] == PublishFrom.TASK_PIPELINE
def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
)
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
assert output == "masked content"
assert flagged is True
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked content"
def test_start_thread(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
mock_current_app._get_current_object.return_value = mock_app
with patch("threading.Thread") as mock_thread_class:
mock_thread_instance = MagicMock()
mock_thread_class.return_value = mock_thread_instance
thread = output_moderation.start_thread()
assert thread == mock_thread_instance
mock_thread_class.assert_called_once()
mock_thread_instance.start.assert_called_once()
def test_stop_thread(self, output_moderation):
mock_thread = MagicMock()
mock_thread.is_alive.return_value = True
output_moderation.thread = mock_thread
output_moderation.stop_thread()
assert output_moderation.thread_running is False
output_moderation.thread_running = True
mock_thread.is_alive.return_value = False
output_moderation.stop_thread()
assert output_moderation.thread_running is True
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_success(self, mock_factory_class, output_moderation):
mock_factory = mock_factory_class.return_value
mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_outputs.return_value = mock_result
result = output_moderation.moderation("tenant", "app", "buffer")
assert result == mock_result
mock_factory_class.assert_called_once_with(
name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
)
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_exception(self, mock_factory_class, output_moderation):
mock_factory_class.side_effect = Exception("error")
result = output_moderation.moderation("tenant", "app", "buffer")
assert result is None
def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
# Test exit on thread_running=False
output_moderation.thread_running = False
output_moderation.worker(mock_app, 10)
# Should exit immediately
def test_worker_no_flag(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
output_moderation.buffer = "safe"
output_moderation.is_final_chunk = True
# To avoid infinite loop, we'll set thread_running to False after one iteration
def side_effect(*args, **kwargs):
output_moderation.thread_running = False
return mock_moderation.return_value
mock_moderation.side_effect = side_effect
output_moderation.worker(mock_app, 10)
assert mock_moderation.called
def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
)
output_moderation.buffer = "badword"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
assert output_moderation.final_output == "preset"
mock_queue_manager.publish.assert_called_once()
# It breaks on DIRECT_OUTPUT
def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
# Use side_effect to change thread_running on second call
def side_effect(*args, **kwargs):
if mock_moderation.call_count > 1:
output_moderation.thread_running = False
return None
return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")
mock_moderation.side_effect = side_effect
output_moderation.buffer = "badword"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked"
def test_worker_chunk_too_small(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch("time.sleep") as mock_sleep:
# chunk_length < buffer_size and not is_final_chunk
output_moderation.buffer = "123" # length 3
output_moderation.is_final_chunk = False
def sleep_side_effect(seconds):
output_moderation.thread_running = False
mock_sleep.side_effect = sleep_side_effect
output_moderation.worker(mock_app, 10) # buffer_size 10
mock_sleep.assert_called_once_with(1)
def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
# Return None (exception or no rule)
mock_moderation.return_value = None
def side_effect(*args, **kwargs):
output_moderation.thread_running = False
mock_moderation.side_effect = side_effect
output_moderation.buffer = "something"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
mock_queue_manager.publish.assert_not_called()

View File

@@ -0,0 +1,160 @@
from unittest.mock import patch
import httpx
import pytest
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import (
TidbOnQdrantConfig,
TidbOnQdrantVector,
)
class TestTidbOnQdrantVectorDeleteByIds:
"""Unit tests for TidbOnQdrantVector.delete_by_ids method."""
@pytest.fixture
def vector_instance(self):
"""Create a TidbOnQdrantVector instance for testing."""
config = TidbOnQdrantConfig(
endpoint="http://localhost:6333",
api_key="test_api_key",
)
with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"):
vector = TidbOnQdrantVector(
collection_name="test_collection",
group_id="test_group",
config=config,
)
return vector
def test_delete_by_ids_with_multiple_ids(self, vector_instance):
"""Test batch deletion with multiple document IDs."""
ids = ["doc1", "doc2", "doc3"]
vector_instance.delete_by_ids(ids)
# Verify that delete was called once with MatchAny filter
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
# Check collection name
assert call_args[1]["collection_name"] == "test_collection"
# Verify filter uses MatchAny with all IDs
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
assert len(filter_obj.must) == 1
field_condition = filter_obj.must[0]
assert field_condition.key == "metadata.doc_id"
assert isinstance(field_condition.match, rest.MatchAny)
assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"}
def test_delete_by_ids_with_single_id(self, vector_instance):
"""Test deletion with a single document ID."""
ids = ["doc1"]
vector_instance.delete_by_ids(ids)
# Verify that delete was called once
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
# Verify filter uses MatchAny with single ID
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
field_condition = filter_obj.must[0]
assert isinstance(field_condition.match, rest.MatchAny)
assert field_condition.match.any == ["doc1"]
def test_delete_by_ids_with_empty_list(self, vector_instance):
"""Test deletion with empty ID list returns early without API call."""
vector_instance.delete_by_ids([])
# Verify that delete was NOT called
vector_instance._client.delete.assert_not_called()
def test_delete_by_ids_with_404_error(self, vector_instance):
"""Test that 404 errors (collection not found) are handled gracefully."""
ids = ["doc1", "doc2"]
# Mock a 404 error
error = UnexpectedResponse(
status_code=404,
reason_phrase="Not Found",
content=b"Collection not found",
headers=httpx.Headers(),
)
vector_instance._client.delete.side_effect = error
# Should not raise an exception
vector_instance.delete_by_ids(ids)
# Verify delete was called
vector_instance._client.delete.assert_called_once()
def test_delete_by_ids_with_unexpected_error(self, vector_instance):
"""Test that non-404 errors are re-raised."""
ids = ["doc1", "doc2"]
# Mock a 500 error
error = UnexpectedResponse(
status_code=500,
reason_phrase="Internal Server Error",
content=b"Server error",
headers=httpx.Headers(),
)
vector_instance._client.delete.side_effect = error
# Should re-raise the exception
with pytest.raises(UnexpectedResponse) as exc_info:
vector_instance.delete_by_ids(ids)
assert exc_info.value.status_code == 500
def test_delete_by_ids_with_large_batch(self, vector_instance):
"""Test deletion with a large batch of IDs."""
# Create 1000 IDs
ids = [f"doc_{i}" for i in range(1000)]
vector_instance.delete_by_ids(ids)
# Verify single delete call with all IDs
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
field_condition = filter_obj.must[0]
# Verify all 1000 IDs are in the batch
assert len(field_condition.match.any) == 1000
assert "doc_0" in field_condition.match.any
assert "doc_999" in field_condition.match.any
def test_delete_by_ids_filter_structure(self, vector_instance):
"""Test that the filter structure is correctly constructed."""
ids = ["doc1", "doc2"]
vector_instance.delete_by_ids(ids)
call_args = vector_instance._client.delete.call_args
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
# Verify Filter structure
assert isinstance(filter_obj, rest.Filter)
assert filter_obj.must is not None
assert len(filter_obj.must) == 1
# Verify FieldCondition structure
field_condition = filter_obj.must[0]
assert isinstance(field_condition, rest.FieldCondition)
assert field_condition.key == "metadata.doc_id"
# Verify MatchAny structure
assert isinstance(field_condition.match, rest.MatchAny)
assert field_condition.match.any == ids

View File

@@ -0,0 +1,677 @@
from __future__ import annotations
import dataclasses
import json
from collections.abc import Sequence
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
_HumanInputFormEntityImpl,
_HumanInputFormRecipientEntityImpl,
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from dify_graph.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
from libs.datetime_utils import naive_utc_now
from models.human_input import HumanInputFormRecipient, RecipientType
@pytest.fixture(autouse=True)
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
class _FakeSelect:
def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
return self
def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
return self
def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
return self
monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
payload: dict[str, Any] = {
"form_content": "hi",
"inputs": [],
"user_actions": [{"id": "submit", "title": "Submit"}],
"rendered_content": "<p>hi</p>",
}
if include_expiration_time:
payload["expiration_time"] = naive_utc_now()
return json.dumps(payload, default=str)
@dataclasses.dataclass
class _DummyForm:
id: str
workflow_run_id: str | None
node_id: str
tenant_id: str
app_id: str
form_definition: str
rendered_content: str
expiration_time: datetime
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
selected_action_id: str | None = None
submitted_data: str | None = None
submitted_at: datetime | None = None
submission_user_id: str | None = None
submission_end_user_id: str | None = None
completed_by_recipient_id: str | None = None
status: HumanInputFormStatus = HumanInputFormStatus.WAITING
@dataclasses.dataclass
class _DummyRecipient:
id: str
form_id: str
recipient_type: RecipientType
access_token: str | None
class _FakeScalarResult:
def __init__(self, obj: Any):
self._obj = obj
def first(self) -> Any:
if isinstance(self._obj, list):
return self._obj[0] if self._obj else None
return self._obj
def all(self) -> list[Any]:
if self._obj is None:
return []
if isinstance(self._obj, list):
return list(self._obj)
return [self._obj]
class _FakeExecuteResult:
def __init__(self, rows: Sequence[tuple[Any, ...]]):
self._rows = list(rows)
def all(self) -> list[tuple[Any, ...]]:
return list(self._rows)
class _FakeSession:
def __init__(
self,
*,
scalars_result: Any = None,
scalars_results: list[Any] | None = None,
forms: dict[str, _DummyForm] | None = None,
recipients: dict[str, _DummyRecipient] | None = None,
execute_rows: Sequence[tuple[Any, ...]] = (),
):
if scalars_results is not None:
self._scalars_queue = list(scalars_results)
else:
self._scalars_queue = [scalars_result]
self._forms = forms or {}
self._recipients = recipients or {}
self._execute_rows = list(execute_rows)
self.added: list[Any] = []
def scalars(self, _query: Any) -> _FakeScalarResult:
if self._scalars_queue:
value = self._scalars_queue.pop(0)
else:
value = None
return _FakeScalarResult(value)
def execute(self, _stmt: Any) -> _FakeExecuteResult:
return _FakeExecuteResult(self._execute_rows)
def get(self, model_cls: Any, obj_id: str) -> Any:
name = getattr(model_cls, "__name__", "")
if name == "HumanInputForm":
return self._forms.get(obj_id)
if name == "HumanInputFormRecipient":
return self._recipients.get(obj_id)
return None
def add(self, obj: Any) -> None:
self.added.append(obj)
def add_all(self, objs: Sequence[Any]) -> None:
self.added.extend(list(objs))
def flush(self) -> None:
# Simulate DB default population for attributes referenced in entity wrappers.
for obj in self.added:
if hasattr(obj, "id") and obj.id in (None, ""):
obj.id = f"gen-{len(str(self.added))}"
if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
if obj.recipient_type == RecipientType.CONSOLE:
obj.access_token = "token-console"
elif obj.recipient_type == RecipientType.BACKSTAGE:
obj.access_token = "token-backstage"
else:
obj.access_token = "token-webapp"
def refresh(self, _obj: Any) -> None:
return None
def begin(self) -> _FakeSession:
return self
def __enter__(self) -> _FakeSession:
return self
def __exit__(self, exc_type, exc, tb) -> None:
return None
class _SessionFactoryStub:
def __init__(self, session: _FakeSession):
self._session = session
def create_session(self) -> _FakeSession:
return self._session
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))
def test_recipient_entity_token_raises_when_missing() -> None:
recipient = SimpleNamespace(id="r1", access_token=None)
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
with pytest.raises(AssertionError, match="access_token should not be None"):
_ = entity.token
def test_recipient_entity_id_and_token_success() -> None:
recipient = SimpleNamespace(id="r1", access_token="tok")
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
assert entity.id == "r1"
assert entity.token == "tok"
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
form = _DummyForm(
id="f1",
workflow_run_id="run",
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
webapp = _DummyRecipient(
id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
)
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type]
assert entity.web_app_token == "ctok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type]
assert entity.web_app_token == "wtok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
assert entity.web_app_token is None
def test_form_entity_submitted_data_parsed() -> None:
form = _DummyForm(
id="f1",
workflow_run_id="run",
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
submitted_data='{"a": 1}',
submitted_at=naive_utc_now(),
)
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
assert entity.submitted is True
assert entity.submitted_data == {"a": 1}
assert entity.rendered_content == "<p>x</p>"
assert entity.selected_action_id is None
assert entity.status == HumanInputFormStatus.WAITING
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
expiration = naive_utc_now()
form = _DummyForm(
id="f1",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=False),
rendered_content="<p>x</p>",
expiration_time=expiration,
submitted_data='{"k": "v"}',
)
record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type]
assert record.definition.expiration_time == expiration
assert record.submitted_data == {"k": "v"}
assert record.submitted is False
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
created: list[SimpleNamespace] = []
def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def]
recipient = SimpleNamespace(
id=f"{payload.TYPE}-{len(created)}",
form_id=form_id,
delivery_id=delivery_id,
recipient_type=payload.TYPE,
recipient_payload=payload.model_dump_json(),
access_token="tok",
)
created.append(recipient)
return recipient
monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined]
form_id="f",
delivery_id="d",
members=[
_WorkspaceMemberInfo(user_id="u1", email=""),
_WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
_WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
],
external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
)
assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == []
def test_query_workspace_members_by_ids_maps_rows() -> None:
session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"])
assert rows == [
_WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
_WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
]
def test_query_all_workspace_members_maps_rows() -> None:
session = _FakeSession(execute_rows=[("u1", "a@example.com")])
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
rows = repo._query_all_workspace_members(session=session)
assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
def test_repository_init_sets_tenant_id() -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
assert repo._tenant_id == "tenant"
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
result = repo._delivery_method_to_model(
session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
)
assert result.delivery.id == "del-1"
assert result.delivery.form_id == "form-1"
assert len(result.recipients) == 1
assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
called: dict[str, Any] = {}
def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
called.update(
{"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
)
return ["r"]
monkeypatch.setattr(repo, "_build_email_recipients", fake_build)
method = EmailDeliveryMethod(
config=EmailDeliveryConfig(
recipients=EmailRecipients(
whole_workspace=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
),
subject="s",
body="b",
)
)
result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
assert result.recipients == ["r"]
assert called["delivery_id"] == "del-1"
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
monkeypatch.setattr(
repo,
"_query_all_workspace_members",
lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
)
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
recipients = repo._build_email_recipients(
session=MagicMock(),
form_id="f",
delivery_id="d",
recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
)
assert recipients == ["ok"]
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
assert restrict_to_user_ids == ["u1"]
return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
recipients = repo._build_email_recipients(
session=MagicMock(),
form_id="f",
delivery_id="d",
recipients_config=EmailRecipients(
whole_workspace=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
),
)
assert recipients == ["ok"]
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
assert repo.get_form("run", "node") is None
form = _DummyForm(
id="f1",
workflow_run_id="run",
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
recipient = _DummyRecipient(
id="r1",
form_id=form.id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
access_token="tok",
)
session = _FakeSession(scalars_results=[form, [recipient]])
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
entity = repo.get_form("run", "node")
assert entity is not None
assert entity.id == "f1"
assert entity.recipients[0].id == "r1"
assert entity.recipients[0].token == "tok"
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))
session = _FakeSession()
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
form_config = HumanInputNodeData(
title="Title",
delivery_methods=[],
form_content="hello",
inputs=[],
user_actions=[UserAction(id="submit", title="Submit")],
)
params = FormCreateParams(
app_id="app",
workflow_execution_id="run",
node_id="node",
form_config=form_config,
rendered_content="<p>hello</p>",
delivery_methods=[WebAppDeliveryMethod()],
display_in_ui=True,
resolved_default_values={},
form_kind=HumanInputFormKind.RUNTIME,
console_recipient_required=True,
console_creator_account_id="acc-1",
backstage_recipient_required=True,
)
entity = repo.create_form(params)
assert entity.id == "form-id"
assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
# Console token should take precedence when console recipient is present.
assert entity.web_app_token == "token-console"
assert len(entity.recipients) == 3
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
repo = HumanInputFormSubmissionRepository()
assert repo.get_by_token("tok") is None
recipient = SimpleNamespace(form=None)
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
repo = HumanInputFormSubmissionRepository()
assert repo.get_by_token("tok") is None
def test_submission_repository_init_no_args() -> None:
repo = HumanInputFormSubmissionRepository()
assert isinstance(repo, HumanInputFormSubmissionRepository)
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f1",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
recipient = SimpleNamespace(
id="r1",
form_id=form.id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
access_token="tok",
form=form,
)
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
repo = HumanInputFormSubmissionRepository()
record = repo.get_by_token("tok")
assert record is not None
assert record.access_token == "tok"
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
repo = HumanInputFormSubmissionRepository()
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
assert record is not None
assert record.recipient_id == "r1"
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
repo = HumanInputFormSubmissionRepository()
assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
missing_session = _FakeSession(forms={})
_patch_session_factory(monkeypatch, missing_session)
repo = HumanInputFormSubmissionRepository()
with pytest.raises(FormNotFoundError, match="form not found"):
repo.mark_submitted(
form_id="missing",
recipient_id=None,
selected_action_id="a",
form_data={},
submission_user_id=None,
submission_end_user_id=None,
)
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=fixed_now,
)
recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.mark_submitted(
form_id=form.id,
recipient_id=recipient.id,
selected_action_id="approve",
form_data={"k": "v"},
submission_user_id="u",
submission_end_user_id="eu",
)
assert form.status == HumanInputFormStatus.SUBMITTED
assert form.submitted_at == fixed_now
assert record.submitted_data == {"k": "v"}
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type]
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
status=HumanInputFormStatus.TIMEOUT,
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
assert record.status == HumanInputFormStatus.TIMEOUT
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
status=HumanInputFormStatus.SUBMITTED,
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
with pytest.raises(FormNotFoundError, match="form already submitted"):
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
selected_action_id="a",
submitted_data="{}",
submission_user_id="u",
submission_end_user_id="eu",
completed_by_recipient_id="r",
status=HumanInputFormStatus.WAITING,
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
assert form.status == HumanInputFormStatus.EXPIRED
assert form.selected_action_id is None
assert form.submitted_data is None
assert form.submission_user_id is None
assert form.submission_end_user_id is None
assert form.completed_by_recipient_id is None
assert record.status == HumanInputFormStatus.EXPIRED
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(forms={}))
repo = HumanInputFormSubmissionRepository()
with pytest.raises(FormNotFoundError, match="form not found"):
repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)

View File

@@ -1,84 +1,291 @@
from datetime import datetime
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from models import Account, CreatorUserRole, EndUser, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
engine = create_engine("sqlite:///:memory:")
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
user = MagicMock(spec=Account)
user.id = str(uuid4())
user.current_tenant_id = str(uuid4())
repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=real_session_factory,
user=user,
app_id="app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
session_context = MagicMock()
session_context.__enter__.return_value = session
session_context.__exit__.return_value = False
repository._session_factory = MagicMock(return_value=session_context)
return repository
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
return WorkflowExecution.new(
id_=execution_id,
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0.0",
graph={"nodes": [], "edges": []},
inputs={"query": "hello"},
started_at=started_at,
)
def test_save_uses_execution_started_at_when_record_does_not_exist():
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
session_factory = MagicMock(spec=sessionmaker)
session = MagicMock()
session.get.return_value = None
repository = _build_repository_with_mocked_session(session)
started_at = datetime(2026, 1, 1, 12, 0, 0)
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
session_factory.return_value.__enter__.return_value = session
return session_factory
def test_save_preserves_existing_created_at_when_record_already_exists():
session = MagicMock()
repository = _build_repository_with_mocked_session(session)
@pytest.fixture
def mock_engine():
"""Mock SQLAlchemy Engine."""
return MagicMock(spec=Engine)
execution_id = str(uuid4())
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repository._tenant_id
existing_run.created_at = existing_created_at
session.get.return_value = existing_run
execution = _build_execution(
execution_id=execution_id,
started_at=datetime(2026, 1, 1, 12, 30, 0),
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = MagicMock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = MagicMock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
outputs={"output1": "result1"},
status=WorkflowExecutionStatus.SUCCEEDED,
error_message="",
total_tokens=100,
total_steps=5,
exceptions_count=0,
started_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()
class TestSQLAlchemyWorkflowExecutionRepository:
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
app_id = "test_app_id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
)
assert repo._session_factory == mock_session_factory
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role == CreatorUserRole.ACCOUNT
def test_init_with_engine(self, mock_engine, mock_account):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_engine,
user=mock_account,
app_id="test_app_id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert isinstance(repo._session_factory, sessionmaker)
assert repo._session_factory.kw["bind"] == mock_engine
def test_init_invalid_session_factory(self, mock_account):
with pytest.raises(ValueError, match="Invalid session_factory type"):
SQLAlchemyWorkflowExecutionRepository(
session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
)
def test_init_no_tenant_id(self, mock_session_factory):
user = MagicMock(spec=Account)
user.current_tenant_id = None
with pytest.raises(ValueError, match="User must have a tenant_id"):
SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
)
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
)
assert repo._tenant_id == mock_end_user.tenant_id
assert repo._creator_user_role == CreatorUserRole.END_USER
def test_to_domain_model(self, mock_session_factory, mock_account):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
)
db_model = MagicMock(spec=WorkflowRun)
db_model.id = str(uuid4())
db_model.workflow_id = str(uuid4())
db_model.type = "workflow"
db_model.version = "1.0"
db_model.inputs_dict = {"in": "val"}
db_model.outputs_dict = {"out": "val"}
db_model.graph_dict = {"nodes": []}
db_model.status = "succeeded"
db_model.error = "some error"
db_model.total_tokens = 50
db_model.total_steps = 3
db_model.exceptions_count = 1
db_model.created_at = datetime.now(UTC)
db_model.finished_at = datetime.now(UTC)
domain_model = repo._to_domain_model(db_model)
assert domain_model.id_ == db_model.id
assert domain_model.workflow_id == db_model.workflow_id
assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
assert domain_model.inputs == db_model.inputs_dict
assert domain_model.error_message == "some error"
def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Make elapsed time deterministic to avoid flaky tests
sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
db_model = repo._to_db_model(sample_workflow_execution)
assert db_model.id == sample_workflow_execution.id_
assert db_model.tenant_id == repo._tenant_id
assert db_model.app_id == "test_app"
assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
assert db_model.status == sample_workflow_execution.status.value
assert db_model.total_tokens == sample_workflow_execution.total_tokens
assert db_model.elapsed_time == 10.0
def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Test with empty/None fields
sample_workflow_execution.graph = None
sample_workflow_execution.inputs = None
sample_workflow_execution.outputs = None
sample_workflow_execution.error_message = None
sample_workflow_execution.finished_at = None
db_model = repo._to_db_model(sample_workflow_execution)
assert db_model.graph is None
assert db_model.inputs is None
assert db_model.outputs is None
assert db_model.error is None
assert db_model.elapsed_time == 0
def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=None,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
db_model = repo._to_db_model(sample_workflow_execution)
assert not hasattr(db_model, "app_id") or db_model.app_id is None
assert db_model.tenant_id == repo._tenant_id
def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
)
# Test triggered_from missing
with pytest.raises(ValueError, match="triggered_from is required"):
repo._to_db_model(sample_workflow_execution)
repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo._creator_user_id = None
with pytest.raises(ValueError, match="created_by is required"):
repo._to_db_model(sample_workflow_execution)
repo._creator_user_id = "some_id"
repo._creator_user_role = None
with pytest.raises(ValueError, match="created_by_role is required"):
repo._to_db_model(sample_workflow_execution)
def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(sample_workflow_execution)
session = mock_session_factory.return_value.__enter__.return_value
session.merge.assert_called_once()
session.commit.assert_called_once()
# Check cache
assert sample_workflow_execution.id_ in repo._execution_cache
cached_model = repo._execution_cache[sample_workflow_execution.id_]
assert cached_model.id == sample_workflow_execution.id_
def test_save_uses_execution_started_at_when_record_does_not_exist(
self, mock_session_factory, mock_account, sample_workflow_execution
):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
sample_workflow_execution.started_at = started_at
session = mock_session_factory.return_value.__enter__.return_value
session.get.return_value = None
repo.save(sample_workflow_execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
def test_save_preserves_existing_created_at_when_record_already_exists(
self, mock_session_factory, mock_account, sample_workflow_execution
):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
execution_id = sample_workflow_execution.id_
existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repo._tenant_id
existing_run.created_at = existing_created_at
session = mock_session_factory.return_value.__enter__.return_value
session.get.return_value = existing_run
sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
repo.save(sample_workflow_execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()

View File

@@ -0,0 +1,772 @@
from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock
import psycopg2.errors
import pytest
from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
_deterministic_json_dump,
_filter_by_offload_type,
_find_first,
_replace_or_append_offload,
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
user = Mock(spec=Account)
user.id = user_id
user.current_tenant_id = tenant_id
return user
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
user = Mock(spec=EndUser)
user.id = user_id
user.tenant_id = tenant_id
return user
def _execution(
*,
execution_id: str = "exec-id",
node_execution_id: str = "node-exec-id",
workflow_run_id: str = "run-id",
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
inputs: Mapping[str, Any] | None = None,
outputs: Mapping[str, Any] | None = None,
process_data: Mapping[str, Any] | None = None,
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
) -> WorkflowNodeExecution:
return WorkflowNodeExecution(
id=execution_id,
node_execution_id=node_execution_id,
workflow_id="workflow-id",
workflow_execution_id=workflow_run_id,
index=1,
predecessor_node_id=None,
node_id="node-id",
node_type=NodeType.LLM,
title="Title",
inputs=inputs,
outputs=outputs,
process_data=process_data,
status=status,
error=None,
elapsed_time=1.0,
metadata=metadata,
created_at=datetime.now(UTC),
finished_at=None,
)
class _SessionCtx:
def __init__(self, session: Any):
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type, exc, tb) -> None:
return None
def _session_factory(session: Any) -> sessionmaker:
factory = Mock(spec=sessionmaker)
factory.return_value = _SessionCtx(session)
return factory
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
engine: Engine = create_engine("sqlite:///:memory:")
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=engine,
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert isinstance(repo._session_factory, sessionmaker)
sm = Mock(spec=sessionmaker)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=sm,
user=_mock_end_user(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert repo._creator_user_role.value == "end_user"
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
with pytest.raises(ValueError, match="Invalid session_factory type"):
SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type]
session_factory=object(),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
user = _mock_account()
user.current_tenant_id = None
with pytest.raises(ValueError, match="User must have a tenant_id"):
SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=user,
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
created: dict[str, Any] = {}
class FakeTruncator:
def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
created.update(
{
"max_size_bytes": max_size_bytes,
"array_element_limit": array_element_limit,
"string_length_limit": string_length_limit,
}
)
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
FakeTruncator,
)
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
_ = repo._create_truncator()
assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
assert _find_first([], lambda _: True) is None
assert _find_first([1, 2, 3], lambda x: x > 1) == 2
off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2
replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS))
assert len(replaced) == 2
assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})
# Happy path: deterministic json dump should be sorted
db_model = repo._to_db_model(execution)
assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1
repo._triggered_from = None
with pytest.raises(ValueError, match="triggered_from is required"):
repo._to_db_model(execution)
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution()
db_model = repo._to_db_model(execution)
assert db_model.app_id == "app"
repo._creator_user_id = None
with pytest.raises(ValueError, match="created_by is required"):
repo._to_db_model(execution)
repo._creator_user_id = "user"
repo._creator_user_role = None
with pytest.raises(ValueError, match="created_by_role is required"):
repo._to_db_model(execution)
def test_is_duplicate_key_error_and_regenerate_id(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
unique = Mock(spec=psycopg2.errors.UniqueViolation)
duplicate_error = IntegrityError("dup", params=None, orig=unique)
assert repo._is_duplicate_key_error(duplicate_error) is True
assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False
execution = _execution(execution_id="old-id")
db_model = WorkflowNodeExecutionModel()
db_model.id = "old-id"
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
caplog.set_level(logging.WARNING)
repo._regenerate_id_on_duplicate(execution, db_model)
assert execution.id == "new-id"
assert db_model.id == "new-id"
assert any("Duplicate key conflict" in r.message for r in caplog.records)
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
session = MagicMock()
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_model = WorkflowNodeExecutionModel()
db_model.id = "id1"
db_model.node_execution_id = "node1"
db_model.foo = "bar" # type: ignore[attr-defined]
db_model.__dict__["_private"] = "x"
existing = SimpleNamespace()
session.get.return_value = existing
repo._persist_to_database(db_model)
assert existing.foo == "bar"
session.add.assert_not_called()
assert repo._node_execution_cache["node1"] is db_model
session.reset_mock()
session.get.return_value = None
repo._node_execution_cache.clear()
repo._persist_to_database(db_model)
session.add.assert_called_once_with(db_model)
assert repo._node_execution_cache["node1"] is db_model
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None
class FakeTruncator:
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
return value, False
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
uploaded: dict[str, Any] = {}
class FakeFileService:
def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def]
uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
return SimpleNamespace(id="file-id", key="file-key")
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
class FakeTruncator:
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
return {"truncated": True}, True
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
assert result is not None
assert result.truncated_value == {"truncated": True}
assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
assert result.offload.file_id == "file-id"
assert result.offload.type_ == ExecutionOffLoadType.INPUTS
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_model = WorkflowNodeExecutionModel()
db_model.id = "id"
db_model.node_execution_id = "node-exec"
db_model.workflow_id = "wf"
db_model.workflow_run_id = "run"
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node"
db_model.node_type = NodeType.LLM
db_model.title = "t"
db_model.inputs = json.dumps({"trunc": "i"})
db_model.process_data = json.dumps({"trunc": "p"})
db_model.outputs = json.dumps({"trunc": "o"})
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
db_model.error = None
db_model.elapsed_time = 0.1
db_model.execution_metadata = json.dumps({"total_tokens": 3})
db_model.created_at = datetime.now(UTC)
db_model.finished_at = None
off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
off_in.file = SimpleNamespace(key="k-in")
off_out.file = SimpleNamespace(key="k-out")
off_proc.file = SimpleNamespace(key="k-proc")
db_model.offload_data = [off_out, off_in, off_proc]
def fake_load(key: str) -> bytes:
return json.dumps({"full": key}).encode()
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
domain = repo._to_domain_model(db_model)
assert domain.inputs == {"full": "k-in"}
assert domain.outputs == {"full": "k-out"}
assert domain.process_data == {"full": "k-proc"}
assert domain.get_truncated_inputs() == {"trunc": "i"}
assert domain.get_truncated_outputs() == {"trunc": "o"}
assert domain.get_truncated_process_data() == {"trunc": "p"}
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_model = WorkflowNodeExecutionModel()
db_model.id = "id"
db_model.node_execution_id = "node-exec"
db_model.workflow_id = "wf"
db_model.workflow_run_id = "run"
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node"
db_model.node_type = NodeType.LLM
db_model.title = "t"
db_model.inputs = json.dumps({"i": 1})
db_model.process_data = json.dumps({"p": 2})
db_model.outputs = json.dumps({"o": 3})
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
db_model.error = None
db_model.elapsed_time = 0.1
db_model.execution_metadata = "{}"
db_model.created_at = datetime.now(UTC)
db_model.finished_at = None
db_model.offload_data = []
domain = repo._to_domain_model(db_model)
assert domain.inputs == {"i": 1}
assert domain.outputs == {"o": 3}
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
class FakeConverter:
def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
return {"wrapped": values["a"]}
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
FakeConverter,
)
assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}'
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
session = MagicMock()
session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
id="id",
offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
inputs=None,
outputs=None,
process_data=None,
)
session.merge = Mock()
session.flush = Mock()
session.begin.return_value.__enter__ = Mock(return_value=session)
session.begin.return_value.__exit__ = Mock(return_value=None)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
trunc_result = SimpleNamespace(
truncated_value={"trunc": True},
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
)
monkeypatch.setattr(
repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
)
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
repo.save_execution_data(execution)
# Inputs should be truncated, outputs/process_data encoded directly
db_model = session.merge.call_args.args[0]
assert json.loads(db_model.inputs) == {"trunc": True}
assert json.loads(db_model.outputs) == {"b": 2}
assert json.loads(db_model.process_data) == {"c": 3}
assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
assert execution.get_truncated_inputs() == {"trunc": True}
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
existing = SimpleNamespace(
id="id",
offload_data=[],
inputs=None,
outputs=None,
process_data=None,
)
session = MagicMock()
session.execute.return_value.scalars.return_value.first.return_value = existing
session.merge = Mock()
session.flush = Mock()
session.begin.return_value.__enter__ = Mock(return_value=session)
session.begin.return_value.__exit__ = Mock(return_value=None)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
if values == {"b": 2}:
return SimpleNamespace(
truncated_value={"b": "trunc"},
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
)
if values == {"c": 3}:
return SimpleNamespace(
truncated_value={"c": "trunc"},
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
)
return None
monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
repo.save_execution_data(execution)
db_model = session.merge.call_args.args[0]
assert json.loads(db_model.outputs) == {"b": "trunc"}
assert json.loads(db_model.process_data) == {"c": "trunc"}
assert execution.get_truncated_outputs() == {"b": "trunc"}
assert execution.get_truncated_process_data() == {"c": "trunc"}
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
session = MagicMock()
session.execute.return_value.scalars.return_value.first.return_value = None
session.merge = Mock()
session.flush = Mock()
session.begin.return_value.__enter__ = Mock(return_value=session)
session.begin.return_value.__exit__ = Mock(return_value=None)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"a": 1})
fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
repo.save_execution_data(execution)
merged = session.merge.call_args.args[0]
assert merged.inputs == '{"a": 1}'
def test_save_retries_duplicate_and_logs_non_duplicate(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(execution_id="id")
unique = Mock(spec=psycopg2.errors.UniqueViolation)
duplicate_error = IntegrityError("dup", params=None, orig=unique)
other_error = IntegrityError("other", params=None, orig=None)
calls = {"n": 0}
def persist(_db_model: Any) -> None:
calls["n"] += 1
if calls["n"] == 1:
raise duplicate_error
monkeypatch.setattr(repo, "_persist_to_database", persist)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
repo.save(execution)
assert execution.id == "new-id"
assert repo._node_execution_cache[execution.node_execution_id] is not None
caplog.set_level(logging.ERROR)
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
with pytest.raises(IntegrityError):
repo.save(_execution(execution_id="id2", node_execution_id="node2"))
assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
def test_save_logs_and_reraises_on_unexpected_error(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
caplog.set_level(logging.ERROR)
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom")))
with pytest.raises(RuntimeError, match="boom"):
repo.save(_execution(execution_id="id3", node_execution_id="node3"))
assert any("Failed to save workflow node execution" in r.message for r in caplog.records)
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
class FakeStmt:
def __init__(self) -> None:
self.where_calls = 0
self.order_by_args: tuple[Any, ...] | None = None
def where(self, *_args: Any) -> FakeStmt:
self.where_calls += 1
return self
def order_by(self, *args: Any) -> FakeStmt:
self.order_by_args = args
return self
stmt = FakeStmt()
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
lambda _q: stmt,
)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
model1 = SimpleNamespace(node_execution_id="n1")
model2 = SimpleNamespace(node_execution_id=None)
session = MagicMock()
session.scalars.return_value.all.return_value = [model1, model2]
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
db_models = repo.get_db_models_by_workflow_run("run", order)
assert db_models == [model1, model2]
assert repo._node_execution_cache["n1"] is model1
assert stmt.order_by_args is not None
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
class FakeStmt:
def where(self, *_args: Any) -> FakeStmt:
return self
def order_by(self, *args: Any) -> FakeStmt:
self.args = args # type: ignore[attr-defined]
return self
stmt = FakeStmt()
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
lambda _q: stmt,
)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
session = MagicMock()
session.scalars.return_value.all.return_value = []
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")
class FakeExecutor:
def __enter__(self) -> FakeExecutor:
return self
def __exit__(self, exc_type, exc, tb) -> None:
return None
def map(self, func, items, timeout: int): # type: ignore[no-untyped-def]
assert timeout == 30
return list(map(func, items))
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
lambda max_workers: FakeExecutor(),
)
result = repo.get_by_workflow_run("run", order_config=None)
assert result == ["domain:db1", "domain:db2"]

View File

@@ -0,0 +1,137 @@
import json
from unittest.mock import patch
from core.schemas.registry import SchemaRegistry
class TestSchemaRegistry:
def test_initialization(self, tmp_path):
base_dir = tmp_path / "schemas"
base_dir.mkdir()
registry = SchemaRegistry(str(base_dir))
assert registry.base_dir == base_dir
assert registry.versions == {}
assert registry.metadata == {}
def test_default_registry_singleton(self):
registry1 = SchemaRegistry.default_registry()
registry2 = SchemaRegistry.default_registry()
assert registry1 is registry2
assert isinstance(registry1, SchemaRegistry)
def test_load_all_versions_non_existent_dir(self, tmp_path):
base_dir = tmp_path / "non_existent"
registry = SchemaRegistry(str(base_dir))
registry.load_all_versions()
assert registry.versions == {}
def test_load_all_versions_filtering(self, tmp_path):
base_dir = tmp_path / "schemas"
base_dir.mkdir()
(base_dir / "not_a_version_dir").mkdir()
(base_dir / "v1").mkdir()
(base_dir / "some_file.txt").write_text("content")
registry = SchemaRegistry(str(base_dir))
with patch.object(registry, "_load_version_dir") as mock_load:
registry.load_all_versions()
mock_load.assert_called_once()
assert mock_load.call_args[0][0] == "v1"
def test_load_version_dir_filtering(self, tmp_path):
version_dir = tmp_path / "v1"
version_dir.mkdir()
(version_dir / "schema1.json").write_text("{}")
(version_dir / "not_a_schema.txt").write_text("content")
registry = SchemaRegistry(str(tmp_path))
with patch.object(registry, "_load_schema") as mock_load:
registry._load_version_dir("v1", version_dir)
mock_load.assert_called_once()
assert mock_load.call_args[0][1] == "schema1"
def test_load_version_dir_non_existent(self, tmp_path):
version_dir = tmp_path / "non_existent"
registry = SchemaRegistry(str(tmp_path))
registry._load_version_dir("v1", version_dir)
assert "v1" not in registry.versions
def test_load_schema_success(self, tmp_path):
schema_path = tmp_path / "test.json"
schema_content = {"title": "Test Schema", "description": "A test schema"}
schema_path.write_text(json.dumps(schema_content))
registry = SchemaRegistry(str(tmp_path))
registry.versions["v1"] = {}
registry._load_schema("v1", "test", schema_path)
assert registry.versions["v1"]["test"] == schema_content
uri = "https://dify.ai/schemas/v1/test.json"
assert registry.metadata[uri]["title"] == "Test Schema"
assert registry.metadata[uri]["version"] == "v1"
def test_load_schema_invalid_json(self, tmp_path, caplog):
schema_path = tmp_path / "invalid.json"
schema_path.write_text("invalid json")
registry = SchemaRegistry(str(tmp_path))
registry.versions["v1"] = {}
registry._load_schema("v1", "invalid", schema_path)
assert "Failed to load schema v1/invalid" in caplog.text
def test_load_schema_os_error(self, tmp_path, caplog):
schema_path = tmp_path / "error.json"
schema_path.write_text("{}")
registry = SchemaRegistry(str(tmp_path))
registry.versions["v1"] = {}
with patch("builtins.open", side_effect=OSError("Read error")):
registry._load_schema("v1", "error", schema_path)
assert "Failed to load schema v1/error" in caplog.text
def test_get_schema(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v1": {"test": {"type": "object"}}}
# Valid URI
assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
# Invalid URI
assert registry.get_schema("invalid-uri") is None
# Missing version
assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None
def test_list_versions(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v2": {}, "v1": {}}
assert registry.list_versions() == ["v1", "v2"]
def test_list_schemas(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v1": {"b": {}, "a": {}}}
assert registry.list_schemas("v1") == ["a", "b"]
assert registry.list_schemas("v2") == []
def test_get_all_schemas_for_version(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v1": {"test": {"title": "Test Label"}}}
results = registry.get_all_schemas_for_version("v1")
assert len(results) == 1
assert results[0]["name"] == "test"
assert results[0]["label"] == "Test Label"
assert results[0]["schema"] == {"title": "Test Label"}
# Default label if title missing
registry.versions["v1"]["no_title"] = {}
results = registry.get_all_schemas_for_version("v1")
item = next(r for r in results if r["name"] == "no_title")
assert item["label"] == "no_title"
# Empty if version missing
assert registry.get_all_schemas_for_version("v2") == []

View File

@@ -0,0 +1,80 @@
from unittest.mock import MagicMock, patch
from core.schemas.registry import SchemaRegistry
from core.schemas.schema_manager import SchemaManager
def test_init_with_provided_registry():
mock_registry = MagicMock(spec=SchemaRegistry)
manager = SchemaManager(registry=mock_registry)
assert manager.registry == mock_registry
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
def test_init_with_default_registry(mock_default_registry):
mock_registry = MagicMock(spec=SchemaRegistry)
mock_default_registry.return_value = mock_registry
manager = SchemaManager()
mock_default_registry.assert_called_once()
assert manager.registry == mock_registry
def test_get_all_schema_definitions():
mock_registry = MagicMock(spec=SchemaRegistry)
expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
mock_registry.get_all_schemas_for_version.return_value = expected_definitions
manager = SchemaManager(registry=mock_registry)
result = manager.get_all_schema_definitions(version="v2")
mock_registry.get_all_schemas_for_version.assert_called_once_with("v2")
assert result == expected_definitions
def test_get_schema_by_name_success():
mock_registry = MagicMock(spec=SchemaRegistry)
mock_schema = {"type": "object"}
mock_registry.get_schema.return_value = mock_schema
manager = SchemaManager(registry=mock_registry)
result = manager.get_schema_by_name("my_schema", version="v1")
expected_uri = "https://dify.ai/schemas/v1/my_schema.json"
mock_registry.get_schema.assert_called_once_with(expected_uri)
assert result == {"name": "my_schema", "schema": mock_schema}
def test_get_schema_by_name_not_found():
mock_registry = MagicMock(spec=SchemaRegistry)
mock_registry.get_schema.return_value = None
manager = SchemaManager(registry=mock_registry)
result = manager.get_schema_by_name("non_existent", version="v1")
assert result is None
def test_list_available_schemas():
mock_registry = MagicMock(spec=SchemaRegistry)
expected_schemas = ["schema1", "schema2"]
mock_registry.list_schemas.return_value = expected_schemas
manager = SchemaManager(registry=mock_registry)
result = manager.list_available_schemas(version="v1")
mock_registry.list_schemas.assert_called_once_with("v1")
assert result == expected_schemas
def test_list_available_versions():
mock_registry = MagicMock(spec=SchemaRegistry)
expected_versions = ["v1", "v2"]
mock_registry.list_versions.return_value = expected_versions
manager = SchemaManager(registry=mock_registry)
result = manager.list_available_versions()
mock_registry.list_versions.assert_called_once()
assert result == expected_versions

View File

@@ -1303,6 +1303,24 @@ class TestBillingServiceSubscriptionOperations:
# Assert
assert result == {}
def test_get_plan_bulk_converts_string_expiration_date_to_int(self, mock_send_request):
"""Test bulk plan retrieval converts string expiration_date to int."""
# Arrange
tenant_ids = ["tenant-1"]
mock_send_request.return_value = {
"data": {
"tenant-1": {"plan": "sandbox", "expiration_date": "1735689600"},
}
}
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert "tenant-1" in result
assert isinstance(result["tenant-1"]["expiration_date"], int)
assert result["tenant-1"]["expiration_date"] == 1735689600
def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request):
"""Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant)."""
# Arrange

View File

@@ -0,0 +1,558 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
MetadataArgs,
MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@dataclass
class _DocumentStub:
id: str
name: str
uploader: str
upload_date: datetime
last_update_date: datetime
data_source_type: str
doc_metadata: dict[str, object] | None
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
mocked_db = mocker.patch("services.metadata_service.db")
mocked_db.session = MagicMock()
return mocked_db
@pytest.fixture
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
return mocker.patch("services.metadata_service.redis_client")
@pytest.fixture
def mock_current_account(mocker: MockerFixture) -> MagicMock:
mock_user = SimpleNamespace(id="user-1")
return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1"))
def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
return _DocumentStub(
id=document_id,
name=f"doc-{document_id}",
uploader="qa@example.com",
upload_date=now,
last_update_date=now,
data_source_type="upload_file",
doc_metadata=doc_metadata,
)
def _dataset(**kwargs: Any) -> Dataset:
return cast(Dataset, SimpleNamespace(**kwargs))
def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
# Arrange
metadata_args = MetadataArgs(type="string", name="x" * 256)
# Act + Assert
with pytest.raises(ValueError, match="cannot exceed 255"):
MetadataService.create_metadata("dataset-1", metadata_args)
def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
mock_db: MagicMock,
mock_current_account: MagicMock,
) -> None:
# Arrange
metadata_args = MetadataArgs(type="string", name="priority")
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
# Act + Assert
with pytest.raises(ValueError, match="already exists"):
MetadataService.create_metadata("dataset-1", metadata_args)
# Assert
mock_current_account.assert_called_once()
def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
# Arrange
metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Built-in fields"):
MetadataService.create_metadata("dataset-1", metadata_args)
def test_create_metadata_should_persist_metadata_when_input_is_valid(
mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
# Arrange
metadata_args = MetadataArgs(type="number", name="score")
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
# Act
result = MetadataService.create_metadata("dataset-1", metadata_args)
# Assert
assert result.tenant_id == "tenant-1"
assert result.dataset_id == "dataset-1"
assert result.type == "number"
assert result.name == "score"
assert result.created_by == "user-1"
mock_db.session.add.assert_called_once_with(result)
mock_db.session.commit.assert_called_once()
mock_current_account.assert_called_once()
def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None:
# Arrange
too_long_name = "x" * 256
# Act + Assert
with pytest.raises(ValueError, match="cannot exceed 255"):
MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name)
def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists(
mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
# Arrange
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
# Act + Assert
with pytest.raises(ValueError, match="already exists"):
MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate")
# Assert
mock_current_account.assert_called_once()
def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin(
mock_db: MagicMock,
mock_current_account: MagicMock,
) -> None:
# Arrange
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Built-in fields"):
MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source)
# Assert
mock_current_account.assert_called_once()
def test_update_metadata_name_should_update_bound_documents_and_return_metadata(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mock_current_account: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC)
mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now)
metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None)
bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
query_duplicate = MagicMock()
query_duplicate.filter_by.return_value.first.return_value = None
query_metadata = MagicMock()
query_metadata.filter_by.return_value.first.return_value = metadata
query_bindings = MagicMock()
query_bindings.filter_by.return_value.all.return_value = bindings
mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings]
doc_1 = _build_document("1", {"old_name": "value", "other": "keep"})
doc_2 = _build_document("2", None)
mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids")
mock_get_documents.return_value = [doc_1, doc_2]
# Act
result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name")
# Assert
assert result is metadata
assert metadata.name == "new_name"
assert metadata.updated_by == "user-1"
assert metadata.updated_at == fixed_now
assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"}
assert doc_2.doc_metadata == {"new_name": None}
mock_get_documents.assert_called_once_with(["doc-1", "doc-2"])
mock_db.session.commit.assert_called_once()
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
mock_current_account.assert_called_once()
def test_update_metadata_name_should_return_none_when_metadata_does_not_exist(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mock_current_account: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
mock_logger = mocker.patch("services.metadata_service.logger")
query_duplicate = MagicMock()
query_duplicate.filter_by.return_value.first.return_value = None
query_metadata = MagicMock()
query_metadata.filter_by.return_value.first.return_value = None
mock_db.session.query.side_effect = [query_duplicate, query_metadata]
# Act
result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name")
# Assert
assert result is None
mock_logger.exception.assert_called_once()
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
mock_current_account.assert_called_once()
def test_delete_metadata_should_remove_metadata_and_related_document_fields(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
metadata = SimpleNamespace(id="metadata-1", name="obsolete")
bindings = [SimpleNamespace(document_id="doc-1")]
query_metadata = MagicMock()
query_metadata.filter_by.return_value.first.return_value = metadata
query_bindings = MagicMock()
query_bindings.filter_by.return_value.all.return_value = bindings
mock_db.session.query.side_effect = [query_metadata, query_bindings]
document = _build_document("1", {"obsolete": "legacy", "remaining": "value"})
mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document])
# Act
result = MetadataService.delete_metadata("dataset-1", "metadata-1")
# Assert
assert result is metadata
assert document.doc_metadata == {"remaining": "value"}
mock_db.session.delete.assert_called_once_with(metadata)
mock_db.session.commit.assert_called_once()
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_delete_metadata_should_return_none_when_metadata_is_missing(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_logger = mocker.patch("services.metadata_service.logger")
# Act
result = MetadataService.delete_metadata("dataset-1", "missing-id")
# Assert
assert result is None
mock_logger.exception.assert_called_once()
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_get_built_in_fields_should_return_all_expected_fields() -> None:
# Arrange
expected_names = {
BuiltInField.document_name,
BuiltInField.uploader,
BuiltInField.upload_date,
BuiltInField.last_update_date,
BuiltInField.source,
}
# Act
result = MetadataService.get_built_in_fields()
# Assert
assert {item["name"] for item in result} == expected_names
assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"]
def test_enable_built_in_field_should_return_immediately_when_already_enabled(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
# Act
MetadataService.enable_built_in_field(dataset)
# Assert
get_docs.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_enable_built_in_field_should_populate_documents_and_enable_flag(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
doc_1 = _build_document("1", {"custom": "value"})
doc_2 = _build_document("2", None)
mocker.patch(
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
return_value=[doc_1, doc_2],
)
# Act
MetadataService.enable_built_in_field(dataset)
# Assert
assert dataset.built_in_field_enabled is True
assert doc_1.doc_metadata is not None
assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1"
assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
assert doc_2.doc_metadata is not None
assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com"
mock_db.session.commit.assert_called_once()
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_disable_built_in_field_should_return_immediately_when_already_disabled(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
# Act
MetadataService.disable_built_in_field(dataset)
# Assert
get_docs.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
document = _build_document(
"1",
{
BuiltInField.document_name: "doc",
BuiltInField.uploader: "user",
BuiltInField.upload_date: 1.0,
BuiltInField.last_update_date: 2.0,
BuiltInField.source: MetadataDataSource.upload_file,
"custom": "keep",
},
)
mocker.patch(
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
return_value=[document],
)
# Act
MetadataService.disable_built_in_field(dataset)
# Assert
assert dataset.built_in_field_enabled is False
assert document.doc_metadata == {"custom": "keep"}
mock_db.session.commit.assert_called_once()
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mock_current_account: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
document = _build_document("1", {"legacy": "value"})
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
delete_chain = mock_db.session.query.return_value.filter_by.return_value
delete_chain.delete.return_value = 1
operation = DocumentMetadataOperation(
document_id="1",
metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")],
partial_update=False,
)
metadata_args = MetadataOperationData(operation_data=[operation])
# Act
MetadataService.update_documents_metadata(dataset, metadata_args)
# Assert
assert document.doc_metadata == {"priority": "high"}
delete_chain.delete.assert_called_once()
assert mock_db.session.commit.call_count == 1
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
mock_current_account.assert_called_once()
def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mock_current_account: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
document = _build_document("1", {"existing": "value"})
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
operation = DocumentMetadataOperation(
document_id="1",
metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")],
partial_update=True,
)
metadata_args = MetadataOperationData(operation_data=[operation])
# Act
MetadataService.update_documents_metadata(dataset, metadata_args)
# Assert
assert document.doc_metadata is not None
assert document.doc_metadata["existing"] == "value"
assert document.doc_metadata["new_key"] == "new_value"
assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
assert mock_db.session.commit.call_count == 1
assert mock_db.session.add.call_count == 1
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
mock_current_account.assert_called_once()
def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found(
mock_db: MagicMock,
mock_redis_client: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None)
operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True)
metadata_args = MetadataOperationData(operation_data=[operation])
# Act + Assert
with pytest.raises(ValueError, match="Document not found"):
MetadataService.update_documents_metadata(dataset, metadata_args)
# Assert
mock_db.session.rollback.assert_called_once()
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404")
@pytest.mark.parametrize(
("dataset_id", "document_id", "expected_key"),
[
("dataset-1", None, "dataset_metadata_lock_dataset-1"),
(None, "doc-1", "document_metadata_lock_doc-1"),
],
)
def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked(
dataset_id: str | None,
document_id: str | None,
expected_key: str,
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act
MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
# Assert
mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = 1
# Act + Assert
with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = 1
# Act + Assert
with pytest.raises(ValueError, match="document metadata operation is running"):
MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
# Arrange
dataset = _dataset(
id="dataset-1",
built_in_field_enabled=True,
doc_metadata=[
{"id": "meta-1", "name": "priority", "type": "string"},
{"id": "built-in", "name": "ignored", "type": "string"},
{"id": "meta-2", "name": "score", "type": "number"},
],
)
count_chain = mock_db.session.query.return_value.filter_by.return_value
count_chain.count.side_effect = [3, 1]
# Act
result = MetadataService.get_dataset_metadatas(dataset)
# Assert
assert result["built_in_field_enabled"] is True
assert result["doc_metadata"] == [
{"id": "meta-1", "name": "priority", "type": "string", "count": 3},
{"id": "meta-2", "name": "score", "type": "number", "count": 1},
]
def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
# Arrange
dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
# Act
result = MetadataService.get_dataset_metadatas(dataset)
# Assert
assert result == {"doc_metadata": [], "built_in_field_enabled": False}
mock_db.session.query.assert_not_called()

View File

@@ -0,0 +1,808 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from constants import HIDDEN_VALUE
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FieldModelSchema,
FormType,
ModelCredentialSchema,
ProviderCredentialSchema,
)
from models.provider import LoadBalancingModelConfig
from services.model_load_balancing_service import ModelLoadBalancingService
def _build_provider_credential_schema() -> ProviderCredentialSchema:
return ProviderCredentialSchema(
credential_form_schemas=[
CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT)
]
)
def _build_model_credential_schema() -> ModelCredentialSchema:
return ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model")),
credential_form_schemas=[
CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT)
],
)
def _build_provider_configuration(
*,
custom_provider: bool = False,
load_balancing_enabled: bool | None = None,
model_schema: ModelCredentialSchema | None = None,
provider_schema: ProviderCredentialSchema | None = None,
) -> MagicMock:
provider_configuration = MagicMock()
provider_configuration.provider = SimpleNamespace(
provider="openai",
model_credential_schema=model_schema,
provider_credential_schema=provider_schema,
)
provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider)
provider_configuration.extract_secret_variables.return_value = ["api_key"]
provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials
provider_configuration.get_provider_model_setting.return_value = (
None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled)
)
return provider_configuration
def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig:
return cast(LoadBalancingModelConfig, SimpleNamespace(**kwargs))
@pytest.fixture
def service(mocker: MockerFixture) -> ModelLoadBalancingService:
# Arrange
provider_manager = MagicMock()
mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager)
svc = ModelLoadBalancingService()
svc.provider_manager = provider_manager
return svc
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
# Arrange
mocked_db = mocker.patch("services.model_load_balancing_service.db")
mocked_db.session = MagicMock()
return mocked_db
@pytest.mark.parametrize(
("method_name", "expected_provider_method"),
[
("enable_model_load_balancing", "enable_model_load_balancing"),
("disable_model_load_balancing", "disable_model_load_balancing"),
],
)
def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists(
method_name: str,
expected_provider_method: str,
service: ModelLoadBalancingService,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
# Act
getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
# Assert
getattr(provider_configuration, expected_provider_method).assert_called_once_with(
model="gpt-4o-mini", model_type=ModelType.LLM
)
@pytest.mark.parametrize(
"method_name",
["enable_model_load_balancing", "disable_model_load_balancing"],
)
def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing(
method_name: str,
service: ModelLoadBalancingService,
) -> None:
# Arrange
service.provider_manager.get_configurations.return_value = {}
# Act + Assert
with pytest.raises(ValueError, match="Provider openai does not exist"):
getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing(
service: ModelLoadBalancingService,
) -> None:
# Arrange
service.provider_manager.get_configurations.return_value = {}
# Act + Assert
with pytest.raises(ValueError, match="Provider openai does not exist"):
service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider(
service: ModelLoadBalancingService,
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(
custom_provider=True,
load_balancing_enabled=True,
provider_schema=_build_provider_credential_schema(),
)
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
config = SimpleNamespace(
id="cfg-1",
name="primary",
encrypted_config=json.dumps({"api_key": "encrypted-key"}),
credential_id="cred-1",
enabled=True,
)
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config]
mocker.patch(
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
return_value=("rsa", "cipher"),
)
mocker.patch(
"services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
return_value="plain-key",
)
mocker.patch(
"services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
return_value=(False, 0),
)
# Act
is_enabled, configs = service.get_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
)
# Assert
assert is_enabled is True
assert len(configs) == 2
assert configs[0]["name"] == "__inherit__"
assert configs[1]["name"] == "primary"
assert configs[1]["credentials"] == {"api_key": "plain-key"}
assert mock_db.session.add.call_count == 1
assert mock_db.session.commit.call_count == 1
def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors(
service: ModelLoadBalancingService,
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(
custom_provider=True,
load_balancing_enabled=None,
provider_schema=_build_provider_credential_schema(),
)
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
normal_config = SimpleNamespace(
id="cfg-1",
name="normal",
encrypted_config=json.dumps({"api_key": "bad-encrypted"}),
credential_id="cred-1",
enabled=True,
)
inherit_config = SimpleNamespace(
id="cfg-2",
name="__inherit__",
encrypted_config="not-json",
credential_id=None,
enabled=False,
)
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
normal_config,
inherit_config,
]
mocker.patch(
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
return_value=("rsa", "cipher"),
)
mocker.patch(
"services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
side_effect=ValueError("cannot decrypt"),
)
mocker.patch(
"services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
return_value=(True, 15),
)
# Act
is_enabled, configs = service.get_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
config_from="predefined-model",
)
# Assert
assert is_enabled is False
assert configs[0]["name"] == "__inherit__"
assert configs[0]["credentials"] == {}
assert configs[1]["credentials"] == {"api_key": "bad-encrypted"}
assert configs[1]["in_cooldown"] is True
assert configs[1]["ttl"] == 15
def test_get_load_balancing_config_should_raise_value_error_when_provider_missing(
service: ModelLoadBalancingService,
) -> None:
# Arrange
service.provider_manager.get_configurations.return_value = {}
# Act + Assert
with pytest.raises(ValueError, match="Provider openai does not exist"):
service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
def test_get_load_balancing_config_should_return_none_when_config_not_found(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
# Assert
assert result is None
def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: {
"masked": credentials.get("api_key", "")
}
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
mock_db.session.query.return_value.where.return_value.first.return_value = config
# Act
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
# Assert
assert result == {
"id": "cfg-1",
"name": "primary",
"credentials": {"masked": ""},
"enabled": True,
}
def test_init_inherit_config_should_create_and_persist_inherit_configuration(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
model_type = ModelType.LLM
# Act
inherit_config = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", model_type)
# Assert
assert inherit_config.tenant_id == "tenant-1"
assert inherit_config.provider_name == "openai"
assert inherit_config.model_name == "gpt-4o-mini"
assert inherit_config.model_type == "text-generation"
assert inherit_config.name == "__inherit__"
mock_db.session.add.assert_called_once_with(inherit_config)
mock_db.session.commit.assert_called_once()
def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing(
service: ModelLoadBalancingService,
) -> None:
# Arrange
service.provider_manager.get_configurations.return_value = {}
# Act + Assert
with pytest.raises(ValueError, match="Provider openai does not exist"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[],
"custom-model",
)
def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list(
service: ModelLoadBalancingService,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
# Act + Assert
with pytest.raises(ValueError, match="Invalid load balancing configs"):
service.update_load_balancing_configs( # type: ignore[arg-type]
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
cast(list[dict[str, object]], "invalid-configs"),
"custom-model",
)
def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = []
# Act + Assert
with pytest.raises(ValueError, match="Invalid load balancing config"):
service.update_load_balancing_configs( # type: ignore[list-item]
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
cast(list[dict[str, object]], ["bad-item"]),
"custom-model",
)
def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = []
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"credential_id": "cred-1", "enabled": True}],
"predefined-model",
)
def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = []
# Act + Assert
with pytest.raises(ValueError, match="Invalid load balancing config name"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"enabled": True}],
"custom-model",
)
with pytest.raises(ValueError, match="Invalid load balancing config enabled"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"name": "cfg-without-enabled"}],
"custom-model",
)
def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
current_config = SimpleNamespace(id="cfg-1")
mock_db.session.scalars.return_value.all.return_value = [current_config]
# Act + Assert
with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"id": "cfg-2", "name": "invalid", "enabled": True}],
"custom-model",
)
def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None)
mock_db.session.scalars.return_value.all.return_value = [existing_config]
# Act + Assert
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}],
"custom-model",
)
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"name": "new-config", "enabled": True, "credentials": "bad"}],
"custom-model",
)
def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs(
service: ModelLoadBalancingService,
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
existing_config_1 = SimpleNamespace(
id="cfg-1",
name="existing-one",
enabled=True,
encrypted_config=json.dumps({"api_key": "old"}),
updated_at=None,
)
existing_config_2 = SimpleNamespace(
id="cfg-2",
name="existing-two",
enabled=True,
encrypted_config=None,
updated_at=None,
)
mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2]
mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"})
mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache")
# Act
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[
{"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}},
{"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}},
],
"custom-model",
)
# Assert
assert existing_config_1.name == "updated-name"
assert existing_config_1.enabled is False
assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"}
assert mock_db.session.add.call_count == 1
mock_db.session.delete.assert_called_once_with(existing_config_2)
assert mock_db.session.commit.call_count >= 3
mock_clear_cache.assert_any_call("tenant-1", "cfg-1")
mock_clear_cache.assert_any_call("tenant-1", "cfg-2")
def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = []
# Act + Assert
with pytest.raises(ValueError, match="Invalid load balancing config name"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}],
"custom-model",
)
with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"name": "new", "enabled": True}],
"custom-model",
)
def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = []
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record
# Act
service.update_load_balancing_configs(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
[{"credential_id": "cred-1", "enabled": True}],
"predefined-model",
)
# Assert
created_config = mock_db.session.add.call_args.args[0]
assert created_config.name == "Main Credential"
assert created_config.credential_id == "cred-1"
assert created_config.credential_source_type == "provider"
assert created_config.encrypted_config == '{"api_key":"enc"}'
mock_db.session.commit.assert_called()
def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing(
service: ModelLoadBalancingService,
) -> None:
# Arrange
service.provider_manager.get_configurations.return_value = {}
# Act + Assert
with pytest.raises(ValueError, match="Provider openai does not exist"):
service.validate_load_balancing_credentials(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
{"api_key": "plain"},
)
def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid(
service: ModelLoadBalancingService,
mock_db: MagicMock,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
service.validate_load_balancing_credentials(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
{"api_key": "plain"},
config_id="cfg-1",
)
def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config(
service: ModelLoadBalancingService,
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
existing_config = SimpleNamespace(id="cfg-1")
mock_db.session.query.return_value.where.return_value.first.return_value = existing_config
mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
# Act
service.validate_load_balancing_credentials(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
{"api_key": "plain"},
config_id="cfg-1",
)
service.validate_load_balancing_credentials(
"tenant-1",
"openai",
"gpt-4o-mini",
ModelType.LLM.value,
{"api_key": "plain"},
)
# Assert
assert mock_validate.call_count == 2
assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config
assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None
def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt(
service: ModelLoadBalancingService,
mocker: MockerFixture,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
load_balancing_model_config = _load_balancing_model_config(
encrypted_config=json.dumps({"api_key": "old-encrypted-token"})
)
mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value")
mock_encrypt = mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}",
)
# Act
result = service._custom_credentials_validate(
tenant_id="tenant-1",
provider_configuration=provider_configuration,
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"api_key": HIDDEN_VALUE, "region": "us"},
load_balancing_model_config=load_balancing_model_config,
validate=False,
)
# Assert
assert result == {"api_key": "enc:old-plain-value", "region": "us"}
mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value")
def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema(
service: ModelLoadBalancingService,
mocker: MockerFixture,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema())
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
mock_factory = MagicMock()
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
mock_encrypt = mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}",
)
# Act
result = service._custom_credentials_validate(
tenant_id="tenant-1",
provider_configuration=provider_configuration,
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"api_key": "plain"},
load_balancing_model_config=load_balancing_model_config,
validate=True,
)
# Assert
assert result == {"api_key": "enc:validated"}
mock_factory.model_credentials_validate.assert_called_once()
mock_factory.provider_credentials_validate.assert_not_called()
mock_encrypt.assert_called_once_with("tenant-1", "validated")
def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent(
service: ModelLoadBalancingService,
mocker: MockerFixture,
) -> None:
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
mock_factory = MagicMock()
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}",
)
# Act
result = service._custom_credentials_validate(
tenant_id="tenant-1",
provider_configuration=provider_configuration,
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"api_key": "plain"},
validate=True,
)
# Assert
assert result == {"api_key": "enc:provider-validated"}
mock_factory.provider_credentials_validate.assert_called_once()
mock_factory.model_credentials_validate.assert_not_called()
def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise(
service: ModelLoadBalancingService,
) -> None:
# Arrange
model_schema = _build_model_credential_schema()
provider_schema = _build_provider_credential_schema()
provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema)
provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema)
provider_configuration_without_schema = _build_provider_configuration()
# Act
schema_from_model = service._get_credential_schema(provider_configuration_with_model)
schema_from_provider = service._get_credential_schema(provider_configuration_with_provider)
# Assert
assert schema_from_model is model_schema
assert schema_from_provider is provider_schema
with pytest.raises(ValueError, match="No credential schema found"):
service._get_credential_schema(provider_configuration_without_schema)
def test_clear_credentials_cache_should_delete_load_balancing_cache_entry(
service: ModelLoadBalancingService,
mocker: MockerFixture,
) -> None:
# Arrange
mock_cache_instance = MagicMock()
mock_cache_cls = mocker.patch(
"services.model_load_balancing_service.ProviderCredentialsCache",
return_value=mock_cache_instance,
)
# Act
service._clear_credentials_cache("tenant-1", "cfg-1")
# Assert
mock_cache_cls.assert_called_once()
assert mock_cache_cls.call_args.kwargs == {
"tenant_id": "tenant-1",
"identity_id": "cfg-1",
"cache_type": mocker.ANY,
}
assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL"
mock_cache_instance.delete.assert_called_once()

View File

@@ -0,0 +1,224 @@
from __future__ import annotations
import uuid
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import BadRequest
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
@pytest.fixture
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
return mocker.patch("services.oauth_server.redis_client")
@pytest.fixture
def mock_session(mocker: MockerFixture) -> MagicMock:
"""Mock the OAuth server Session context manager."""
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
session = MagicMock()
session_cm = MagicMock()
session_cm.__enter__.return_value = session
mocker.patch("services.oauth_server.Session", return_value=session_cm)
return session
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
# Arrange
mock_execute_result = MagicMock()
expected_app = MagicMock()
mock_execute_result.scalar_one_or_none.return_value = expected_app
mock_session.execute.return_value = mock_execute_result
# Act
result = OAuthServerService.get_oauth_provider_app("client-1")
# Assert
assert result is expected_app
mock_session.execute.assert_called_once()
mock_execute_result.scalar_one_or_none.assert_called_once()
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
# Act
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
# Assert
expected_code = str(deterministic_uuid)
assert code == expected_code
mock_redis_client.set.assert_called_once_with(
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
"user-1",
ex=600,
)
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act + Assert
with pytest.raises(BadRequest, match="invalid code"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="bad-code",
client_id="client-1",
)
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
token_uuids = [
uuid.UUID("00000000-0000-0000-0000-000000000201"),
uuid.UUID("00000000-0000-0000-0000-000000000202"),
]
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
mock_redis_client.get.return_value = b"user-1"
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
# Act
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="code-1",
client_id="client-1",
)
# Assert
assert access_token == str(token_uuids[0])
assert refresh_token == str(token_uuids[1])
mock_redis_client.delete.assert_called_once_with(code_key)
mock_redis_client.set.assert_any_call(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
mock_redis_client.set.assert_any_call(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
b"user-1",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act + Assert
with pytest.raises(BadRequest, match="invalid refresh token"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="stale-token",
client_id="client-1",
)
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
mock_redis_client.get.return_value = b"user-1"
# Act
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="refresh-1",
client_id="client-1",
)
# Assert
assert access_token == str(deterministic_uuid)
assert returned_refresh_token == "refresh-1"
mock_redis_client.set.assert_called_once_with(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
# Arrange
grant_type = cast(OAuthGrantType, "invalid-grant-type")
# Act
result = OAuthServerService.sign_oauth_access_token(
grant_type=grant_type,
client_id="client-1",
)
# Assert
assert result is None
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
# Act
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
# Assert
assert refresh_token == str(deterministic_uuid)
mock_redis_client.set.assert_called_once_with(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
"user-2",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
# Assert
assert result is None
def test_validate_oauth_access_token_should_load_user_when_token_exists(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
mock_redis_client.get.return_value = b"user-88"
expected_user = MagicMock()
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
# Act
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
# Assert
assert result is expected_user
mock_load_user.assert_called_once_with("user-88")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,259 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models.model import App, EndUser
from services.web_conversation_service import WebConversationService
@pytest.fixture
def app_model() -> App:
return cast(App, SimpleNamespace(id="app-1"))
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
def _end_user(**kwargs: Any) -> EndUser:
return cast(EndUser, SimpleNamespace(**kwargs))
def test_pagination_by_last_id_should_raise_error_when_user_is_none(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
# Act + Assert
with pytest.raises(ValueError, match="User is required"):
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=None,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
)
def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
fake_user = _account(id="user-1")
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
mock_pagination.return_value = MagicMock()
# Act
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=fake_user,
last_id="conv-9",
limit=10,
invoke_from=InvokeFrom.WEB_APP,
pinned=None,
)
# Assert
call_kwargs = mock_pagination.call_args.kwargs
assert call_kwargs["include_ids"] is None
assert call_kwargs["exclude_ids"] is None
assert call_kwargs["last_id"] == "conv-9"
assert call_kwargs["sort_by"] == "-updated_at"
def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
fake_account_cls = type("FakeAccount", (), {})
fake_user = cast(Account, fake_account_cls())
fake_user.id = "account-1"
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {}))
session.scalars.return_value.all.return_value = ["conv-1", "conv-2"]
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
mock_pagination.return_value = MagicMock()
# Act
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=fake_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
pinned=True,
)
# Assert
call_kwargs = mock_pagination.call_args.kwargs
assert call_kwargs["include_ids"] == ["conv-1", "conv-2"]
assert call_kwargs["exclude_ids"] is None
def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
fake_end_user_cls = type("FakeEndUser", (), {})
fake_user = cast(EndUser, fake_end_user_cls())
fake_user.id = "end-user-1"
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
session.scalars.return_value.all.return_value = ["conv-3"]
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
mock_pagination.return_value = MagicMock()
# Act
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=fake_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
pinned=False,
)
# Assert
call_kwargs = mock_pagination.call_args.kwargs
assert call_kwargs["include_ids"] is None
assert call_kwargs["exclude_ids"] == ["conv-3"]
def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
# Arrange
mock_db = mocker.patch("services.web_conversation_service.db")
mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
# Act
WebConversationService.pin(app_model, "conv-1", None)
# Assert
mock_db.session.add.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_pin_should_return_early_when_conversation_is_already_pinned(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_account_cls = type("FakeAccount", (), {})
fake_user = cast(Account, fake_account_cls())
fake_user.id = "account-1"
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
mock_db.session.query.return_value.where.return_value.first.return_value = object()
mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
# Act
WebConversationService.pin(app_model, "conv-1", fake_user)
# Assert
mock_get_conversation.assert_not_called()
mock_db.session.add.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_pin_should_create_pinned_conversation_when_not_already_pinned(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_account_cls = type("FakeAccount", (), {})
fake_user = cast(Account, fake_account_cls())
fake_user.id = "account-2"
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_conversation = SimpleNamespace(id="conv-2")
mock_get_conversation = mocker.patch(
"services.web_conversation_service.ConversationService.get_conversation",
return_value=mock_conversation,
)
# Act
WebConversationService.pin(app_model, "conv-2", fake_user)
# Assert
mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user)
added_obj = mock_db.session.add.call_args.args[0]
assert added_obj.app_id == "app-1"
assert added_obj.conversation_id == "conv-2"
assert added_obj.created_by_role == "account"
assert added_obj.created_by == "account-2"
mock_db.session.commit.assert_called_once()
def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
# Arrange
mock_db = mocker.patch("services.web_conversation_service.db")
# Act
WebConversationService.unpin(app_model, "conv-1", None)
# Assert
mock_db.session.delete.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_unpin_should_return_early_when_conversation_is_not_pinned(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_end_user_cls = type("FakeEndUser", (), {})
fake_user = cast(EndUser, fake_end_user_cls())
fake_user.id = "end-user-3"
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act
WebConversationService.unpin(app_model, "conv-7", fake_user)
# Assert
mock_db.session.delete.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_unpin_should_delete_pinned_conversation_when_exists(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_end_user_cls = type("FakeEndUser", (), {})
fake_user = cast(EndUser, fake_end_user_cls())
fake_user.id = "end-user-4"
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
pinned_obj = SimpleNamespace(id="pin-1")
mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj
# Act
WebConversationService.unpin(app_model, "conv-8", fake_user)
# Assert
mock_db.session.delete.assert_called_once_with(pinned_obj)
mock_db.session.commit.assert_called_once()

View File

@@ -0,0 +1,379 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound, Unauthorized
from models import Account, AccountStatus
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback"
TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token"
TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data"
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
# Arrange
mocked_db = mocker.patch("services.webapp_auth_service.db")
mocked_db.session = MagicMock()
return mocked_db
def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
# Act + Assert
with pytest.raises(AccountNotFoundError):
WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act + Assert
with pytest.raises(AccountLoginError, match="Account is banned"):
WebAppAuthService.authenticate("user@example.com", "pwd")
@pytest.mark.parametrize("password_value", [None, "hash"])
def test_authenticate_should_raise_password_error_when_password_is_invalid(
password_value: str | None,
mocker: MockerFixture,
) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=False)
# Act + Assert
with pytest.raises(AccountPasswordError, match="Invalid email or password"):
WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=True)
# Act
result = WebAppAuthService.authenticate("user@example.com", "pwd")
# Assert
assert result is account
def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None:
# Arrange
account = _account(id="a1", email="u@example.com")
mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token")
# Act
result = WebAppAuthService.login(account)
# Assert
assert result == "jwt-token"
mock_get_token.assert_called_once_with(account=account)
def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
# Act
result = WebAppAuthService.get_user_through_email("missing@example.com")
# Assert
assert result is None
def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.BANNED)
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act + Assert
with pytest.raises(Unauthorized, match="Account is banned"):
WebAppAuthService.get_user_through_email("user@example.com")
def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE)
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act
result = WebAppAuthService.get_user_through_email("user@example.com")
# Assert
assert result is account
def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Email must be provided"):
WebAppAuthService.send_email_code_login_email(account=None, email=None)
def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account(
mocker: MockerFixture,
) -> None:
# Arrange
account = _account(email="user@example.com")
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6])
mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1")
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
# Act
result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
# Assert
assert result == "token-1"
mock_generate_token.assert_called_once()
assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"}
mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456")
def test_send_email_code_login_email_should_send_mail_for_email_without_account(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0])
mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2")
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
# Act
result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans")
# Assert
assert result == "token-2"
mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000")
def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
# Arrange
mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"})
# Act
result = WebAppAuthService.get_email_code_login_data("token-abc")
# Assert
assert result == {"code": "123"}
mock_get_data.assert_called_once_with("token-abc", "email_code_login")
def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
# Arrange
mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token")
# Act
WebAppAuthService.revoke_email_code_login_token("token-xyz")
# Assert
mock_revoke.assert_called_once_with("token-xyz", "email_code_login")
def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(NotFound, match="Site not found"):
WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None:
# Arrange
site = SimpleNamespace(app_id="app-1")
app_query = MagicMock()
app_query.where.return_value.first.return_value = None
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None]
# Act + Assert
with pytest.raises(NotFound, match="App not found"):
WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None:
# Arrange
site = SimpleNamespace(app_id="app-1")
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model]
# Act
result = WebAppAuthService.create_end_user("app-code", "user@example.com")
# Assert
assert result.tenant_id == "tenant-1"
assert result.app_id == "app-1"
assert result.session_id == "user@example.com"
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None:
# Arrange
account = _account(id="a1", email="user@example.com")
mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60)
mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1")
# Act
token = WebAppAuthService._get_account_jwt_token(account)
# Assert
assert token == "jwt-1"
payload = mock_issue.call_args.args[0]
assert payload["user_id"] == "a1"
assert payload["session_id"] == "user@example.com"
assert payload["token_source"] == "webapp_login_token"
assert payload["auth_type"] == "internal"
assert payload["exp"] > int(datetime.now(UTC).timestamp())
@pytest.mark.parametrize(
("access_mode", "expected"),
[
("private", True),
("private_all", True),
("public", False),
],
)
def test_is_app_require_permission_check_should_use_access_mode_when_provided(
access_mode: str,
expected: bool,
) -> None:
# Arrange
# Act
result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode)
# Assert
assert result is expected
def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Either app_code or app_id must be provided"):
WebAppAuthService.is_app_require_permission_check()
def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None)
# Act + Assert
with pytest.raises(ValueError, match="App ID could not be determined"):
WebAppAuthService.is_app_require_permission_check(app_code="app-code")
def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="private"),
)
# Act
result = WebAppAuthService.is_app_require_permission_check(app_code="app-code")
# Assert
assert result is True
def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="public"),
)
# Act
result = WebAppAuthService.is_app_require_permission_check(app_id="app-1")
# Assert
assert result is False
@pytest.mark.parametrize(
("access_mode", "expected"),
[
("public", WebAppAuthType.PUBLIC),
("private", WebAppAuthType.INTERNAL),
("private_all", WebAppAuthType.INTERNAL),
("sso_verified", WebAppAuthType.EXTERNAL),
],
)
def test_get_app_auth_type_should_map_access_modes_correctly(
access_mode: str,
expected: WebAppAuthType,
) -> None:
# Arrange
# Act
result = WebAppAuthService.get_app_auth_type(access_mode=access_mode)
# Assert
assert result == expected
def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="private_all"),
)
# Act
result = WebAppAuthService.get_app_auth_type(app_code="app-code")
# Assert
assert result == WebAppAuthType.INTERNAL
def test_get_app_auth_type_should_raise_when_no_input_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"):
WebAppAuthService.get_app_auth_type()
def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Could not determine app authentication type"):
WebAppAuthService.get_app_auth_type(access_mode="unknown")

View File

@@ -0,0 +1,300 @@
from __future__ import annotations
import json
import uuid
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from dify_graph.enums import WorkflowExecutionStatus
from models import App, WorkflowAppLog
from models.enums import AppTriggerType, CreatorUserRole
from services.workflow_app_service import LogView, WorkflowAppService
@pytest.fixture
def service() -> WorkflowAppService:
# Arrange
return WorkflowAppService()
@pytest.fixture
def app_model() -> App:
# Arrange
return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))
def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog:
return cast(WorkflowAppLog, SimpleNamespace(**kwargs))
def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None:
# Arrange
log = _workflow_app_log(id="log-1", status="succeeded")
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})
# Act
details = view.details
proxied_status = view.status
# Assert
assert details == {"trigger_metadata": {"type": "plugin"}}
assert proxied_status == "succeeded"
def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
log_1 = SimpleNamespace(id="log-1")
log_2 = SimpleNamespace(id="log-2")
session.scalar.return_value = 3
session.scalars.return_value.all.return_value = [log_1, log_2]
# Act
result = service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
page=1,
limit=2,
detail=False,
)
# Assert
assert result["page"] == 1
assert result["limit"] == 2
assert result["total"] == 3
assert result["has_more"] is True
assert len(result["data"]) == 2
assert isinstance(result["data"][0], LogView)
assert result["data"][0].details is None
def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true(
service: WorkflowAppService,
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
session.scalar.side_effect = [1]
log_1 = SimpleNamespace(id="log-1")
session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')]
mock_handle = mocker.patch.object(
service,
"handle_trigger_metadata",
return_value={"type": "trigger_plugin", "icon": "url"},
)
# Act
result = service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
keyword="run-1",
status=WorkflowExecutionStatus.SUCCEEDED,
created_at_before=None,
created_at_after=None,
page=1,
limit=20,
detail=True,
)
# Assert
assert result["total"] == 1
assert len(result["data"]) == 1
assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}}
mock_handle.assert_called_once()
def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Account not found: account@example.com"):
service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
created_by_account="account@example.com",
)
def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0]
session.scalars.return_value.all.return_value = []
# Act
result = service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
created_by_account="account@example.com",
)
# Assert
assert result["total"] == 0
assert result["data"] == []
def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
log_account = SimpleNamespace(
id="log-1",
created_by="acc-1",
created_by_role=CreatorUserRole.ACCOUNT,
workflow_run_summary={"run": "1"},
trigger_metadata='{"type":"trigger-webhook"}',
log_created_at="2026-01-01",
)
log_end_user = SimpleNamespace(
id="log-2",
created_by="end-1",
created_by_role=CreatorUserRole.END_USER,
workflow_run_summary={"run": "2"},
trigger_metadata='{"type":"trigger-webhook"}',
log_created_at="2026-01-02",
)
log_unknown = SimpleNamespace(
id="log-3",
created_by="other",
created_by_role="system",
workflow_run_summary={"run": "3"},
trigger_metadata='{"type":"trigger-webhook"}',
log_created_at="2026-01-03",
)
session.scalar.return_value = 3
session.scalars.side_effect = [
SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]),
SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]),
SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]),
]
# Act
result = service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,
page=1,
limit=20,
)
# Assert
assert result["total"] == 3
assert len(result["data"]) == 3
assert result["data"][0]["created_by_account"].id == "acc-1"
assert result["data"][1]["created_by_end_user"].id == "end-1"
assert result["data"][2]["created_by_account"] is None
assert result["data"][2]["created_by_end_user"] is None
def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing(
service: WorkflowAppService,
) -> None:
# Arrange
# Act
result = service.handle_trigger_metadata("tenant-1", None)
# Assert
assert result == {}
def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin(
service: WorkflowAppService,
mocker: MockerFixture,
) -> None:
# Arrange
meta = {
"type": AppTriggerType.TRIGGER_PLUGIN.value,
"icon_filename": "light.png",
"icon_dark_filename": "dark.png",
}
mock_icon = mocker.patch(
"services.workflow_app_service.PluginService.get_plugin_icon_url",
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
)
# Act
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
# Assert
assert result["icon"] == "https://cdn/light.png"
assert result["icon_dark"] == "https://cdn/dark.png"
assert mock_icon.call_count == 2
def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup(
service: WorkflowAppService,
mocker: MockerFixture,
) -> None:
# Arrange
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url")
# Act
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
# Assert
assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
mock_icon.assert_not_called()
@pytest.mark.parametrize(
("value", "expected"),
[
(None, None),
("", None),
('{"k":"v"}', {"k": "v"}),
("not-json", None),
({"raw": True}, {"raw": True}),
],
)
def test_safe_json_loads_should_handle_various_inputs(
value: object,
expected: object,
service: WorkflowAppService,
) -> None:
# Arrange
# Act
result = service._safe_json_loads(value)
# Assert
assert result == expected
def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None:
# Arrange
# Act
short_result = service._safe_parse_uuid("short")
invalid_result = service._safe_parse_uuid("x" * 40)
# Assert
assert short_result is None
assert invalid_result is None
def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None:
# Arrange
raw_uuid = str(uuid.uuid4())
# Act
result = service._safe_parse_uuid(raw_uuid)
# Assert
assert result is not None
assert str(result) == raw_uuid

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,576 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from models.account import Tenant
# ---------------------------------------------------------------------------
# Constants used throughout the tests
# ---------------------------------------------------------------------------
TENANT_ID = "tenant-abc"
ACCOUNT_ID = "account-xyz"
FILES_BASE_URL = "https://files.example.com"
DB_PATH = "services.workspace_service.db"
FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features"
TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles"
DIFY_CONFIG_PATH = "services.workspace_service.dify_config"
CURRENT_USER_PATH = "services.workspace_service.current_user"
CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool"
# ---------------------------------------------------------------------------
# Helpers / factories
# ---------------------------------------------------------------------------
def _make_tenant(
tenant_id: str = TENANT_ID,
name: str = "My Workspace",
plan: str = "sandbox",
status: str = "active",
custom_config: dict | None = None,
) -> Tenant:
"""Create a minimal Tenant-like namespace."""
return cast(
Tenant,
SimpleNamespace(
id=tenant_id,
name=name,
plan=plan,
status=status,
created_at="2024-01-01T00:00:00Z",
custom_config_dict=custom_config or {},
),
)
def _make_feature(
can_replace_logo: bool = False,
next_credit_reset_date: str | None = None,
billing_plan: str = "sandbox",
) -> MagicMock:
"""Create a feature namespace matching what FeatureService.get_features returns."""
feature = MagicMock()
feature.can_replace_logo = can_replace_logo
feature.next_credit_reset_date = next_credit_reset_date
feature.billing.subscription.plan = billing_plan
return feature
def _make_pool(quota_limit: int, quota_used: int) -> MagicMock:
pool = MagicMock()
pool.quota_limit = quota_limit
pool.quota_used = quota_used
return pool
def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace:
return SimpleNamespace(role=role)
def _tenant_info(result: object) -> dict[str, Any] | None:
return cast(dict[str, Any] | None, result)
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_current_user() -> SimpleNamespace:
"""Return a lightweight current_user stand-in."""
return SimpleNamespace(id=ACCOUNT_ID)
@pytest.fixture
def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
"""
Patch the common external boundaries used by WorkspaceService.get_tenant_info.
Returns a dict of named mocks so individual tests can customise them.
"""
mocker.patch(CURRENT_USER_PATH, mock_current_user)
mock_db_session = mocker.patch(f"{DB_PATH}.session")
mock_query_chain = MagicMock()
mock_db_session.query.return_value = mock_query_chain
mock_query_chain.where.return_value = mock_query_chain
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature())
mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False)
mock_config = mocker.patch(DIFY_CONFIG_PATH)
mock_config.EDITION = "SELF_HOSTED"
mock_config.FILES_URL = FILES_BASE_URL
return {
"db_session": mock_db_session,
"query_chain": mock_query_chain,
"get_features": mock_feature,
"has_roles": mock_has_roles,
"config": mock_config,
}
# ---------------------------------------------------------------------------
# 1. None Tenant Handling
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None:
"""get_tenant_info should short-circuit and return None for a falsy tenant."""
from services.workspace_service import WorkspaceService
# Arrange
tenant = None
# Act
result = WorkspaceService.get_tenant_info(cast(Tenant, tenant))
# Assert
assert result is None
def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None:
"""get_tenant_info treats any falsy value as absent (e.g. empty string, 0)."""
from services.workspace_service import WorkspaceService
# Arrange / Act / Assert
assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# 2. Basic Tenant Info — happy path
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_return_base_fields(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""get_tenant_info should always return the six base scalar fields."""
from services.workspace_service import WorkspaceService
# Arrange
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["id"] == TENANT_ID
assert result["name"] == "My Workspace"
assert result["plan"] == "sandbox"
assert result["status"] == "active"
assert result["created_at"] == "2024-01-01T00:00:00Z"
assert result["trial_end_reason"] is None
def test_get_tenant_info_should_populate_role_from_tenant_account_join(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""The 'role' field should be taken from TenantAccountJoin, not the default."""
from services.workspace_service import WorkspaceService
# Arrange
basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin")
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["role"] == "admin"
def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""
The service asserts that TenantAccountJoin exists.
Missing join should raise AssertionError.
"""
from services.workspace_service import WorkspaceService
# Arrange
basic_mocks["query_chain"].first.return_value = None
tenant = _make_tenant()
# Act + Assert
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
WorkspaceService.get_tenant_info(tenant)
# ---------------------------------------------------------------------------
# 3. Logo Customisation
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""custom_config block should appear for OWNER/ADMIN when can_replace_logo is True."""
from services.workspace_service import WorkspaceService
# Arrange
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
basic_mocks["has_roles"].return_value = True
tenant = _make_tenant(
custom_config={
"replace_webapp_logo": True,
"remove_webapp_brand": True,
}
)
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert "custom_config" in result
assert result["custom_config"]["remove_webapp_brand"] is True
expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo"
assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
from services.workspace_service import WorkspaceService
# Arrange
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
basic_mocks["has_roles"].return_value = True
tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["custom_config"]["replace_webapp_logo"] is None
def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""custom_config should be absent when can_replace_logo is False."""
from services.workspace_service import WorkspaceService
# Arrange
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False)
basic_mocks["has_roles"].return_value = True
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert "custom_config" not in result
def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""custom_config block is gated on OWNER or ADMIN role."""
from services.workspace_service import WorkspaceService
# Arrange
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
basic_mocks["has_roles"].return_value = False # regular member
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert "custom_config" not in result
def test_get_tenant_info_should_use_files_url_for_logo_url(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""The logo URL should use dify_config.FILES_URL as the base."""
from services.workspace_service import WorkspaceService
# Arrange
custom_base = "https://cdn.mycompany.io"
basic_mocks["config"].FILES_URL = custom_base
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
basic_mocks["has_roles"].return_value = True
tenant = _make_tenant(custom_config={"replace_webapp_logo": True})
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
# ---------------------------------------------------------------------------
# 4. Cloud-Edition Credit Features
# ---------------------------------------------------------------------------
CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX
@pytest.fixture
def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
"""Patches for CLOUD edition tests, billing plan = professional by default."""
mocker.patch(CURRENT_USER_PATH, mock_current_user)
mock_db_session = mocker.patch(f"{DB_PATH}.session")
mock_query_chain = MagicMock()
mock_db_session.query.return_value = mock_query_chain
mock_query_chain.where.return_value = mock_query_chain
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
mock_feature = mocker.patch(
FEATURE_SERVICE_PATH,
return_value=_make_feature(
can_replace_logo=False,
next_credit_reset_date="2025-02-01",
billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX,
),
)
mocker.patch(TENANT_SERVICE_PATH, return_value=False)
mock_config = mocker.patch(DIFY_CONFIG_PATH)
mock_config.EDITION = "CLOUD"
mock_config.FILES_URL = FILES_BASE_URL
return {
"db_session": mock_db_session,
"query_chain": mock_query_chain,
"get_features": mock_feature,
"config": mock_config,
}
def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""next_credit_reset_date should be present in CLOUD edition."""
from services.workspace_service import WorkspaceService
# Arrange
mocker.patch(
CREDIT_POOL_SERVICE_PATH,
side_effect=[None, None], # both paid and trial pools absent
)
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["next_credit_reset_date"] == "2025-02-01"
def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""trial_credits/trial_credits_used come from the paid pool when conditions are met."""
from services.workspace_service import WorkspaceService
# Arrange
paid_pool = _make_pool(quota_limit=1000, quota_used=200)
mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool)
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == 1000
assert result["trial_credits_used"] == 200
def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""quota_limit == -1 means unlimited; service should still use the paid pool."""
from services.workspace_service import WorkspaceService
# Arrange
paid_pool = _make_pool(quota_limit=-1, quota_used=999)
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None])
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == -1
assert result["trial_credits_used"] == 999
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""When paid pool is exhausted (used >= limit), switch to trial pool."""
from services.workspace_service import WorkspaceService
# Arrange
paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full
trial_pool = _make_pool(quota_limit=100, quota_used=10)
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == 100
assert result["trial_credits_used"] == 10
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""When paid_pool is None, fall back to trial pool."""
from services.workspace_service import WorkspaceService
# Arrange
trial_pool = _make_pool(quota_limit=50, quota_used=5)
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool])
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == 50
assert result["trial_credits_used"] == 5
def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""
When the subscription plan IS SANDBOX, the paid pool branch is skipped
entirely and we fall back to the trial pool.
"""
from enums.cloud_plan import CloudPlan
from services.workspace_service import WorkspaceService
# Arrange — override billing plan to SANDBOX
cloud_mocks["get_features"].return_value = _make_feature(
next_credit_reset_date="2025-02-01",
billing_plan=CloudPlan.SANDBOX,
)
paid_pool = _make_pool(quota_limit=1000, quota_used=0)
trial_pool = _make_pool(quota_limit=200, quota_used=20)
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == 200
assert result["trial_credits_used"] == 20
def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""When both paid and trial pools are absent, trial_credits should not be set."""
from services.workspace_service import WorkspaceService
# Arrange
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert "trial_credits" not in result
assert "trial_credits_used" not in result
# ---------------------------------------------------------------------------
# 5. Self-hosted / Non-Cloud Edition
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
from services.workspace_service import WorkspaceService
# Arrange (basic_mocks already sets EDITION = "SELF_HOSTED")
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert "next_credit_reset_date" not in result
assert "trial_credits" not in result
assert "trial_credits_used" not in result
# ---------------------------------------------------------------------------
# 6. DB query integrity
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids(
mocker: MockerFixture,
basic_mocks: dict,
) -> None:
"""
The DB query for TenantAccountJoin must be scoped to the correct
tenant_id and current_user.id.
"""
from services.workspace_service import WorkspaceService
# Arrange
tenant = _make_tenant(tenant_id="my-special-tenant")
mock_current_user = mocker.patch(CURRENT_USER_PATH)
mock_current_user.id = "special-user-id"
# Act
WorkspaceService.get_tenant_info(tenant)
# Assert — db.session.query was invoked (at least once)
basic_mocks["db_session"].query.assert_called()

View File

@@ -0,0 +1,643 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.tools.entities.tool_entities import ApiProviderSchemaType
from services.tools.api_tools_manage_service import ApiToolManageService
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
# Arrange
mocked_db = mocker.patch("services.tools.api_tools_manage_service.db")
mocked_db.session = MagicMock()
return mocked_db
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
return SimpleNamespace(operation_id=operation_id)
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
)
# Act
result = ApiToolManageService.parser_api_schema("valid-schema")
# Assert
assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value
assert len(result["credentials_schema"]) == 3
assert "warning" in result
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=RuntimeError("bad schema"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
ApiToolManageService.parser_api_schema("invalid")
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
# Arrange
expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=expected,
)
extra_info: dict[str, str] = {}
# Act
result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info)
# Assert
assert result == expected
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=ValueError("parse failed"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema: parse failed"):
ApiToolManageService.convert_schema_to_tool_bundles("schema")
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = object()
# Act + Assert
with pytest.raises(ValueError, match="provider provider-a already exists"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name=" provider-a ",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
many_tools = [_tool_bundle(str(i)) for i in range(101)]
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=(many_tools, ApiProviderSchemaType.OPENAPI),
)
# Act + Assert
with pytest.raises(ValueError, match="the number of apis should be less than 100"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
# Act + Assert
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
icon={"emoji": "X"},
credentials={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
mock_controller = MagicMock()
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=mock_controller,
)
mock_encrypter = MagicMock()
mock_encrypter.encrypt.return_value = {"auth_type": "none"}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(mock_encrypter, MagicMock()),
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
# Act
result = ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=["news"],
)
# Assert
assert result == {"result": "success"}
mock_controller.load_bundled_tools.assert_called_once()
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.get",
return_value=SimpleNamespace(status_code=200, text="schema-content"),
)
mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})
# Act
result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
# Assert
assert result == {"schema": "schema-content"}
@pytest.mark.parametrize("status_code", [400, 404, 500])
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
status_code: int,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.get",
return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
)
mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger")
# Act + Assert
with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
mock_logger.exception.assert_called_once()
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
mock_db: MagicMock,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="you have not added provider provider-a"):
ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
mock_db.session.query.return_value.where.return_value.first.return_value = provider
controller = MagicMock()
mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
return_value=controller,
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
mock_convert = mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
)
# Act
result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
# Assert
assert result == [{"name": "tool-a"}, {"name": "tool-b"}]
assert mock_convert.call_count == 2
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
mock_db: MagicMock,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="api provider provider-a does not exists"):
ApiToolManageService.update_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
original_provider="provider-a",
icon={},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy=None,
custom_disclaimer="custom",
labels=[],
)
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider = SimpleNamespace(credentials={}, name="old")
mock_db.session.query.return_value.where.return_value.first.return_value = provider
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
# Act + Assert
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.update_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
original_provider="provider-a",
icon={},
credentials={},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy=None,
custom_disclaimer="custom",
labels=[],
)
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider = SimpleNamespace(
credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
name="old",
icon="",
schema="",
description="",
schema_type_str="",
tools_str="",
privacy_policy="",
custom_disclaimer="",
credentials_str="",
)
mock_db.session.query.return_value.where.return_value.first.return_value = provider
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
controller = MagicMock()
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=controller,
)
cache = MagicMock()
encrypter = MagicMock()
encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(encrypter, cache),
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
# Act
result = ApiToolManageService.update_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-new",
original_provider="provider-old",
icon={"emoji": "E"},
credentials={"auth_type": "none", "api_key_value": "***"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=["news"],
)
# Assert
assert result == {"result": "success"}
assert provider.name == "provider-new"
assert provider.privacy_policy == "privacy"
assert provider.credentials_str != ""
cache.delete.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="you have not added provider provider-a"):
ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
# Arrange
provider = object()
mock_db.session.query.return_value.where.return_value.first.return_value = provider
# Act
result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
# Assert
assert result == {"result": "success"}
mock_db.session.delete.assert_called_once_with(provider)
mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
# Arrange
expected = {"provider": "value"}
mock_get = mocker.patch(
"services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
return_value=expected,
)
# Act
result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a")
# Assert
assert result == expected
mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
# Arrange
schema_type = "bad-schema-type"
# Act + Assert
with pytest.raises(ValueError, match="invalid schema type"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type=schema_type, # type: ignore[arg-type]
schema="schema",
)
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=RuntimeError("invalid"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
# Act + Assert
with pytest.raises(ValueError, match="invalid tool name tool-b"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-b",
credentials={"auth_type": "none"},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
# Act + Assert
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
provider_controller = MagicMock()
tool_obj = MagicMock()
tool_obj.fork_tool_runtime.return_value = tool_obj
tool_obj.validate_credentials.side_effect = ValueError("validation failed")
provider_controller.get_tool.return_value = tool_obj
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=provider_controller,
)
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
mock_encrypter.mask_plugin_credentials.return_value = {}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(mock_encrypter, MagicMock()),
)
# Act
result = ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
# Assert
assert result == {"error": "validation failed"}
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
provider_controller = MagicMock()
tool_obj = MagicMock()
tool_obj.fork_tool_runtime.return_value = tool_obj
tool_obj.validate_credentials.return_value = {"ok": True}
provider_controller.get_tool.return_value = tool_obj
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=provider_controller,
)
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
mock_encrypter.mask_plugin_credentials.return_value = {}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(mock_encrypter, MagicMock()),
)
# Act
result = ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={"x": "1"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
# Assert
assert result == {"result": {"ok": True}}
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider_one = SimpleNamespace(name="p1")
provider_two = SimpleNamespace(name="p2")
mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two]
controller_one = MagicMock()
controller_one.get_tools.return_value = ["tool-a"]
controller_two = MagicMock()
controller_two.get_tools.return_value = ["tool-b", "tool-c"]
user_provider_one = SimpleNamespace(labels=[], tools=[])
user_provider_two = SimpleNamespace(labels=[], tools=[])
mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
side_effect=[controller_one, controller_two],
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
side_effect=[user_provider_one, user_provider_two],
)
mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
mock_convert = mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
)
# Act
result = ApiToolManageService.list_api_tools("tenant-1")
# Assert
assert len(result) == 2
assert user_provider_one.tools == [{"name": "tool-a"}]
assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
assert mock_convert.call_count == 3

File diff suppressed because it is too large Load Diff

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "طي القطع",
"segment.contentEmpty": "لا يمكن أن يكون المحتوى فارغًا",
"segment.contentPlaceholder": "أضف المحتوى هنا",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "حذف هذه القطعة؟",
"segment.editChildChunk": "تعديل القطعة الفرعية",
"segment.editChunk": "تعديل القطعة",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Blöcke reduzieren",
"segment.contentEmpty": "Inhalt darf nicht leer sein",
"segment.contentPlaceholder": "Inhalt hier hinzufügen",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Diesen Chunk löschen?",
"segment.editChildChunk": "Untergeordneten Block bearbeiten",
"segment.editChunk": "Chunk bearbeiten",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Collapse chunks",
"segment.contentEmpty": "Content can not be empty",
"segment.contentPlaceholder": "Add content here",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Delete this chunk ?",
"segment.editChildChunk": "Edit Child Chunk",
"segment.editChunk": "Edit Chunk",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Contraer fragmentos",
"segment.contentEmpty": "El contenido no puede estar vacío",
"segment.contentPlaceholder": "agregar contenido aquí",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "¿Eliminar este fragmento?",
"segment.editChildChunk": "Editar fragmento secundario",
"segment.editChunk": "Editar fragmento",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "جمع کردن تکه ها",
"segment.contentEmpty": "محتوا نمی‌تواند خالی باشد",
"segment.contentPlaceholder": "محتوا را اینجا اضافه کنید",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "حذف این قطعه؟",
"segment.editChildChunk": "ویرایش Child Chunk",
"segment.editChunk": "ویرایش تکه",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "चंक्स संक्षिप्त करें",
"segment.contentEmpty": "सामग्री खाली नहीं हो सकती",
"segment.contentPlaceholder": "यहाँ सामग्री जोड़ें",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "इस खंड को हटाएँ ?",
"segment.editChildChunk": "संपादित करें बाल चंक",
"segment.editChunk": "चंक संपादित करें",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Ciutkan potongan",
"segment.contentEmpty": "Konten tidak boleh kosong",
"segment.contentPlaceholder": "Tambahkan konten di sini",
"segment.dateTimeFormat": "MM / DD / YYYY h: mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Hapus potongan ini ?",
"segment.editChildChunk": "Edit Potongan Anak",
"segment.editChunk": "Edit Potongan",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "청크 축소",
"segment.contentEmpty": "내용을 비워둘 수 없습니다",
"segment.contentPlaceholder": "내용을 입력하세요",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "이 청크를 삭제하시겠습니까?",
"segment.editChildChunk": "자손 청크 편집 (Edit Child Chunk)",
"segment.editChunk": "청크 편집 (Edit Chunk)",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Collapse chunks",
"segment.contentEmpty": "Content can not be empty",
"segment.contentPlaceholder": "Add content here",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Delete this chunk ?",
"segment.editChildChunk": "Edit Child Chunk",
"segment.editChunk": "Edit Chunk",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Zwijanie fragmentów",
"segment.contentEmpty": "Treść nie może być pusta",
"segment.contentPlaceholder": "dodaj treść tutaj",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Usunąć ten fragment?",
"segment.editChildChunk": "Edytuj fragment podrzędny",
"segment.editChunk": "Edytuj fragment",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Restrângerea bucăților",
"segment.contentEmpty": "Conținutul nu poate fi gol",
"segment.contentPlaceholder": "adăugați conținutul aici",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Ștergeți acest fragment?",
"segment.editChildChunk": "Editați fragmentul copil",
"segment.editChunk": "Editați bucata",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Сворачивание кусков",
"segment.contentEmpty": "Содержимое не может быть пустым",
"segment.contentPlaceholder": "добавьте содержимое здесь",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Удалить этот фрагмент?",
"segment.editChildChunk": "Редактирование дочернего фрагмента",
"segment.editChunk": "Редактировать фрагмент",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Strniti koščke",
"segment.contentEmpty": "Vsebina ne sme biti prazna",
"segment.contentPlaceholder": "dodajte vsebino tukaj",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Izbriši ta del?",
"segment.editChildChunk": "Urejanje podrejenega kosa",
"segment.editChunk": "Uredi kos",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "ยุบก้อน",
"segment.contentEmpty": "เนื้อหาไม่สามารถว่างเปล่าได้",
"segment.contentPlaceholder": "เพิ่มเนื้อหาที่นี่",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "ลบส่วนนี้ ?",
"segment.editChildChunk": "แก้ไข Child Chunk",
"segment.editChunk": "แก้ไขก้อน",

View File

@@ -300,7 +300,7 @@
"segment.collapseChunks": "Parçaları daraltma",
"segment.contentEmpty": "İçerik boş olamaz",
"segment.contentPlaceholder": "içeriği buraya ekleyin",
"segment.dateTimeFormat": "MM/DD/YYYY h:mm",
"segment.dateTimeFormat": "MM/DD/YYYY HH:mm",
"segment.delete": "Bu parçayı silmek istiyor musunuz?",
"segment.editChildChunk": "Alt Parçayı Düzenle",
"segment.editChunk": "Yığını Düzenle",

View File

@@ -10,6 +10,7 @@ const nextConfig: NextConfig = {
basePath: env.NEXT_PUBLIC_BASE_PATH,
transpilePackages: ['@t3-oss/env-core', '@t3-oss/env-nextjs', 'echarts', 'zrender'],
turbopack: {
root: process.cwd(),
rules: codeInspectorPlugin({
bundler: 'turbopack',
}),