chore: fix use select style api in orm (#35531)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
Asuka Minato
2026-04-24 17:35:20 +09:00
committed by GitHub
parent 0baefa6163
commit c3aebb8403
57 changed files with 347 additions and 5757 deletions

View File

@@ -225,8 +225,10 @@ class TestSpanBuilder:
span = builder.build_span(span_data)
assert isinstance(span, ReadableSpan)
assert span.name == "test-span"
assert span.context is not None
assert span.context.trace_id == 123
assert span.context.span_id == 456
assert span.parent is not None
assert span.parent.span_id == 789
assert span.resource == resource
assert span.attributes == {"attr1": "val1"}

View File

@@ -64,12 +64,13 @@ class TestSpanData:
def test_span_data_missing_required_fields(self):
with pytest.raises(ValidationError):
SpanData(
trace_id=123,
# span_id missing
name="test_span",
start_time=1000,
end_time=2000,
SpanData.model_validate(
{
"trace_id": 123,
"name": "test_span",
"start_time": 1000,
"end_time": 2000,
}
)
def test_span_data_arbitrary_types_allowed(self):

View File

@@ -2,12 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@@ -44,7 +46,7 @@ class RecordingTraceClient:
self.endpoint = endpoint
self.added_spans: list[object] = []
def add_span(self, span) -> None:
def add_span(self, span: object) -> None:
self.added_spans.append(span)
def api_check(self) -> bool:
@@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags.SAMPLED,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(context)
def _make_trace_metadata(
trace_id: int = 1,
workflow_span_id: int = 2,
session_id: str = "s",
user_id: str = "u",
links: list[Link] | None = None,
) -> TraceMetadata:
return TraceMetadata(
trace_id=trace_id,
workflow_span_id=workflow_span_id,
session_id=session_id,
user_id=user_id,
links=[] if links is None else links,
)
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
return cast(RecordingTraceClient, trace_instance.trace_client)
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
defaults = {
"workflow_id": "workflow-id",
@@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
trace_instance.workflow_trace(trace_info)
add_workflow_span.assert_called_once()
passed_trace_metadata = add_workflow_span.call_args.args[1]
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
assert passed_trace_metadata.trace_id == 111
assert passed_trace_metadata.workflow_span_id == 222
assert passed_trace_metadata.session_id == "c"
assert passed_trace_metadata.user_id == "u"
assert passed_trace_metadata.links == []
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_message_trace_info(message_data=None)
trace_instance.message_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
)
trace_instance.message_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 2
message_span, llm_span = trace_instance.trace_client.added_spans
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span, llm_span = spans
assert message_span.name == "message"
assert message_span.trace_id == 10
@@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
trace_instance.dataset_retrieval_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "dataset_retrieval"
assert span.attributes[RETRIEVAL_QUERY] == "query"
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_tool_trace_info(message_data=None)
trace_instance.tool_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
)
)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "my-tool"
assert span.status == status
assert span.attributes[TOOL_NAME] == "my-tool"
@@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
@@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
@@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
@@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
@@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
node_execution.node_type = BuiltinNodeTypes.CODE
@@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "title"
@@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
trace_metadata = _make_trace_metadata(links=[_make_link()])
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "my-tool"
@@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "retrieval"
@@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "llm"
@@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
# CASE 1: With message_id
trace_info = _make_workflow_trace_info(
@@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
)
trace_instance.add_workflow_span(trace_info, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 2
message_span = trace_instance.trace_client.added_spans[0]
workflow_span = trace_instance.trace_client.added_spans[1]
client = _recording_trace_client(trace_instance)
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span = spans[0]
workflow_span = spans[1]
assert message_span.name == "message"
assert message_span.span_kind == SpanKind.SERVER
@@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
assert workflow_span.span_kind == SpanKind.INTERNAL
assert workflow_span.parent_span_id == 20
trace_instance.trace_client.added_spans.clear()
client.added_spans.clear()
# CASE 2: Without message_id
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "workflow"
assert span.span_kind == SpanKind.SERVER
assert span.parent_span_id is None
@@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
trace_instance.suggested_question_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "suggested_question"
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@@ -1,4 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
from dify_trace_aliyun.entities.semconv import (
@@ -170,7 +172,7 @@ def test_create_common_span_attributes():
def test_format_retrieval_documents():
# Not a list
assert format_retrieval_documents("not a list") == []
assert format_retrieval_documents(cast(list[object], "not a list")) == []
# Valid list
docs = [
@@ -211,7 +213,7 @@ def test_format_retrieval_documents():
def test_format_input_messages():
# Not a dict
assert format_input_messages(None) == serialize_json_data([])
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No prompts
assert format_input_messages({}) == serialize_json_data([])
@@ -244,7 +246,7 @@ def test_format_input_messages():
def test_format_output_messages():
# Not a dict
assert format_output_messages(None) == serialize_json_data([])
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No text
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

View File

@@ -25,13 +25,13 @@ class TestAliyunConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
AliyunConfig.model_validate({})
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
AliyunConfig.model_validate({"license_key": "test_license"})
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""

View File

@@ -1,4 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
@@ -129,7 +130,7 @@ def test_set_span_status():
return "SilentErrorRepr"
span.reset_mock()
set_span_status(span, SilentError())
set_span_status(span, cast(Exception | str | None, SilentError()))
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"

View File

@@ -28,13 +28,13 @@ class TestLangfuseConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
LangfuseConfig.model_validate({})
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
LangfuseConfig.model_validate({"public_key": "public"})
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
LangfuseConfig.model_validate({"secret_key": "secret"})
def test_host_validation_empty(self):
"""Test host validation with empty value"""

View File

@@ -2,6 +2,7 @@
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_langfuse.config import LangfuseConfig
@@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
assert trace._get_completion_start_time(start_time, None) is None
assert trace._get_completion_start_time(start_time, -1) is None
assert trace._get_completion_start_time(start_time, "invalid") is None
assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None

View File

@@ -21,13 +21,13 @@ class TestLangSmithConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangSmithConfig()
LangSmithConfig.model_validate({})
with pytest.raises(ValidationError):
LangSmithConfig(api_key="key")
LangSmithConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
LangSmithConfig(project="project")
LangSmithConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@@ -599,7 +599,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_instance.message_trace(_make_message_trace_info())
mock_tracing["start"].assert_called_once()
@@ -609,7 +608,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(error="something broke")
trace_instance.message_trace(trace_info)
@@ -620,7 +618,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
monkeypatch.setenv("FILES_URL", "http://files.test")
file_data = SimpleNamespace(url="path/to/file.png")
@@ -638,7 +635,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
trace_instance.message_trace(trace_info)
@@ -651,7 +647,6 @@ class TestMessageTrace:
end_user = MagicMock()
end_user.session_id = "session-xyz"
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
trace_info = _make_message_trace_info(
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@@ -664,7 +659,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(
metadata={"from_account_id": "acc-1"},

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
@@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
return instance
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_trace)
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_span)
# ---------------------------------------------------------------------------
# _seed_to_uuid4
# ---------------------------------------------------------------------------
@@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
def test_root_span_is_created(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
assert instance.add_span.called
assert _add_span_mock(instance).called
def test_root_span_id_matches_expected(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
expected = self._expected_root_span_id(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected
def test_root_span_has_no_parent(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["parent_span_id"] is None
def test_trace_name_is_workflow_trace(self):
@@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_name_is_workflow_trace(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_has_workflow_tag(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert "workflow" in root_span_kwargs["tags"]
def test_node_execution_spans_are_parented_to_root(self):
@@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
# call_args_list[0] = root span, [1] = node execution span
assert instance.add_span.call_count == 2
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
add_span = _add_span_mock(instance)
assert add_span.call_count == 2
node_span_kwargs = add_span.call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
def test_node_span_not_parented_to_workflow_app_log_id(self):
@@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] != old_parent_id
def test_root_span_id_differs_from_trace_id(self):
@@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
def test_root_span_uses_workflow_run_id_directly(self):
@@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info)
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected_root_span_id
def test_root_span_id_differs_from_no_message_id_case(self):
@@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from typing import Any, TypedDict, cast
from unittest.mock import MagicMock
import pytest
@@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
@@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
self.kwargs = kwargs
class PatchedCoreComponents(TypedDict):
span_exporter: MagicMock
span_processor: MagicMock
tracer: MagicMock
span: MagicMock
tracer_provider: MagicMock
logger: MagicMock
trace_api: Any
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
"""Drop fake metric modules into sys.modules so the client imports resolve."""
@@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
span_exporter = MagicMock(name="span_exporter")
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
@@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
}
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
return SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
def _build_client() -> TencentTraceClient:
return TencentTraceClient(
service_name="service",
@@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
def test_resolve_grpc_target_handles_errors() -> None:
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
@pytest.mark.parametrize(
@@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
client.record_trace_duration(0.5)
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
logger.debug.assert_called()
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
ctx = _make_span_context(span_id=2)
span.get_span_context.return_value = ctx
data = SpanData(
trace_id=1,
@@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
span.add_event.assert_called_once()
span.set_status.assert_called_once()
span.end.assert_called_once_with(end_time=20)
assert client.span_contexts[2] == "ctx"
assert client.span_contexts[2] == ctx
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.span_contexts[10] = "existing"
existing_context = _make_span_context(span_id=10)
client.span_contexts[10] = existing_context
span = patch_core_components["span"]
span.get_span_context.return_value = "child"
span.get_span_context.return_value = _make_span_context(span_id=11)
data = SpanData(
trace_id=1,
@@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
client._create_and_export_span(data)
trace_api = patch_core_components["trace_api"]
trace_api.NonRecordingSpan.assert_called_once_with("existing")
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
trace_api.set_span_in_context.assert_called_once()
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
client.tracer.start_span.side_effect = RuntimeError("boom")
client._create_and_export_span(
@@ -385,7 +407,7 @@ def test_get_project_url() -> None:
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span_processor = patch_core_components["span_processor"]
tracer_provider = patch_core_components["tracer_provider"]
@@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
metric_reader.shutdown.assert_called_once()
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
meter_provider = meter_provider_instances[-1]
meter_provider.shutdown.side_effect = RuntimeError("boom")
assert client.metric_reader is not None
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
client.shutdown()
@@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
assert client.metric_reader is None
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
@@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
logger.exception.assert_called_once()
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
data = SpanData.model_construct(
trace_id=1,
@@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_llm_duration")
client.hist_llm_duration = hist_mock
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["foo"], str)
assert attrs["bar"] == 2
@@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_trace_duration")
client.hist_trace_duration = hist_mock
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["meta"], str)
assert attrs["ok"] is True
@@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
],
)
def test_record_methods_handle_exceptions(
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
) -> None:
client = _build_client()
hist_mock = MagicMock(name=attr_name)
@@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
def test_metrics_initializes_grpc_metric_exporter() -> None:
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
assert isinstance(exporter, DummyGrpcMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert metric_reader.exporter.kwargs["insecure"] is False
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert exporter.kwargs["insecure"] is False
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
assert isinstance(exporter, DummyHttpMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
assert isinstance(exporter, DummyJsonMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in metric_reader.exporter.kwargs
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in exporter.kwargs
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
_ = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in metric_reader.exporter.kwargs
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in exporter.kwargs
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
WeaveConfig.model_validate({})
with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
WeaveConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
WeaveConfig(project="project")
WeaveConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@@ -171,35 +171,13 @@ class TestChatMessageApiPermissions:
parent_message_id=None,
)
class MockQuery:
def __init__(self, model):
self.model = model
def where(self, *args, **kwargs):
return self
def first(self):
if getattr(self.model, "__name__", "") == "Conversation":
return mock_conversation
return None
def order_by(self, *args, **kwargs):
return self
def limit(self, *_):
return self
def all(self):
if getattr(self.model, "__name__", "") == "Message":
return [mock_message]
return []
mock_session = mock.Mock()
mock_session.query.side_effect = MockQuery
mock_session.scalar.return_value = False
mock_session.scalar.return_value = mock_conversation
mock_session.scalars.return_value.all.return_value = [mock_message]
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
monkeypatch.setattr(message_api, "current_user", mock_account)
monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock())
class DummyPagination:
def __init__(self, data, limit, has_more):

View File

@@ -24,7 +24,6 @@ def _patch_wraps():
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

View File

@@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
def _execute_result(rows):
result = mock.Mock()
result.all.return_value = rows
return result
class TestFeedbackService:
"""Test FeedbackService methods."""
@@ -81,25 +87,17 @@ class TestFeedbackService:
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in CSV format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test CSV export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -120,25 +118,17 @@ class TestFeedbackService:
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in JSON format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test JSON export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -157,25 +147,17 @@ class TestFeedbackService:
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
"""Test exporting feedback with various filters."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test with filters
result = FeedbackService.export_feedbacks(
@@ -193,17 +175,7 @@ class TestFeedbackService:
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
"""Test exporting feedback when no data exists."""
# Setup mock query result with no data
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result([])
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -251,24 +223,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -309,24 +274,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None, # No account for user feedback
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -339,32 +297,24 @@ class TestFeedbackService:
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
"""Test that rating emojis are properly formatted in export."""
# Setup mock query result with both like and dislike feedback
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

View File

@@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
configure_session_factory(_unit_test_engine, expire_on_commit=False)
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
"""
Helper to set up the mock DB execute chain for tenant/account authentication.
Helper to stub the tenant-owner execute result for service API app authentication.
This configures the mock to return (tenant, account) for the
db.session.execute(select(...).join().join().where()).one_or_none()
query used by validate_app_token decorator.
The validate_app_token decorator currently resolves the active tenant owner
via db.session.execute(select(Tenant, Account)...).one_or_none().
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_account: Mock account object to return
mock_owner: Mock owner object to return from the execute result
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
"""
Helper to set up the mock DB execute chain for dataset tenant authentication.
Helper to stub the tenant-owner execute result for dataset token authentication.
This configures the mock to return (tenant, tenant_account) for the
db.session.execute(select(...).where().where().where().where()).one_or_none()
query used by validate_dataset_token decorator.
The validate_dataset_token decorator currently resolves the owner mapping via
db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
then loads the Account separately via db.session.get(...).
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_ta: Mock tenant account object to return
mock_tenant_account_join: Mock tenant-account join object to return
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)

View File

@@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
@@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")

View File

@@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
# Act
@@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Act
with self.app.test_request_context(
@@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
# Act
@@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")

View File

@@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
- IP rate limiting is checked
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "email_token_123"
@@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is allowed by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = True
@@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is blocked by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = False
@@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
- Prevents spam and abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.side_effect = AccountRegisterError("Account frozen")
@@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
- Defaults to en-US when not specified
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token"
@@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
- User is logged in with token pair
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
- User is logged in after account creation
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
mock_get_user.return_value = None
mock_create_account.return_value = mock_account
@@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
- InvalidEmailError is raised when email doesn't match token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
# Act & Assert
@@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
- EmailCodeError is raised for wrong verification code
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
# Act & Assert
@@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
- User is added as owner of new workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
- WorkspacesLimitExceeded is raised when limit reached
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
- NotAllowedCreateWorkspace is raised when creation disabled
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []

View File

@@ -110,7 +110,6 @@ class TestLoginApi:
- Rate limit is reset after successful login
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@@ -162,7 +161,6 @@ class TestLoginApi:
- Authentication proceeds with invitation token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
mock_authenticate.return_value = mock_account
@@ -199,7 +197,6 @@ class TestLoginApi:
- EmailPasswordLoginLimitError is raised when limit exceeded
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
mock_get_invitation.return_value = None
@@ -228,7 +225,6 @@ class TestLoginApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_frozen.return_value = True
# Act & Assert
@@ -268,7 +264,6 @@ class TestLoginApi:
- Generic error message prevents user enumeration
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
@@ -305,7 +300,6 @@ class TestLoginApi:
- Login is prevented even with valid credentials
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountLoginError("Account is banned")
@@ -351,7 +345,6 @@ class TestLoginApi:
- User cannot login without an assigned workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@@ -383,7 +376,6 @@ class TestLoginApi:
- Security check prevents invitation token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
@@ -425,7 +417,6 @@ class TestLoginApi:
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
@@ -459,7 +450,6 @@ class TestLoginApi:
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")
@@ -513,7 +503,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
# Act
@@ -539,7 +528,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
# Create a mock anonymous user that will pass isinstance check
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})

View File

@@ -46,7 +46,6 @@ class TestPartnerTenants:
patch("libs.login.dify_config.LOGIN_DISABLED", False),
patch("libs.login.check_csrf_token") as mock_csrf,
):
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}

View File

@@ -24,10 +24,6 @@ def app():
return app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
@@ -64,7 +60,6 @@ class TestChangeEmailSend:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@@ -117,7 +112,6 @@ class TestChangeEmailSend:
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@@ -163,7 +157,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
@@ -223,7 +216,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -280,7 +272,6 @@ class TestChangeEmailValidity:
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -330,7 +321,6 @@ class TestChangeEmailValidity:
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -378,7 +368,6 @@ class TestChangeEmailReset:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -434,7 +423,6 @@ class TestChangeEmailReset:
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -488,7 +476,6 @@ class TestChangeEmailReset:
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@@ -578,7 +564,6 @@ class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask, g
@@ -16,10 +16,6 @@ def app():
return flask_app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
@@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@@ -310,7 +310,6 @@ class TestSystemSetup:
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
@setup_required
def admin_view():

View File

@@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None
@contextmanager
def _mock_db():
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
with patch("extensions.ext_database.db.session", mock_session):
yield

View File

@@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
from controllers.service_api.app.error import AppUnavailableError
from models.account import TenantStatus
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
class TestAppParameterApi:
@@ -74,7 +74,7 @@ class TestAppParameterApi:
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -120,7 +120,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -161,7 +161,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -200,7 +200,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -263,7 +263,7 @@ class TestAppMetaApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -331,7 +331,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -388,7 +388,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -434,7 +434,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -486,7 +486,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):

View File

@@ -15,7 +15,10 @@ from flask import Flask
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import (
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@pytest.fixture
@@ -123,9 +126,7 @@ class AuthenticationMocker:
mock_db.session.get.side_effect = [mock_app, mock_tenant]
if mock_account:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@staticmethod
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
@@ -133,8 +134,7 @@ class AuthenticationMocker:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
mock_db.session.get.return_value = mock_account

View File

@@ -701,8 +701,8 @@ class TestDocumentApiDelete:
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
internally calls ``validate_and_get_api_token``. To bypass the decorator
we call the original function via ``__wrapped__`` (preserved by
``functools.wraps``). ``delete`` queries the dataset via
``db.session.query(Dataset)`` directly, so we patch ``db`` at the
``functools.wraps``). ``delete`` loads the dataset via
``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
controller module.
"""

View File

@@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
from models.model import ApiToken
from tests.unit_tests.conftest import (
setup_mock_dataset_tenant_query,
setup_mock_tenant_account_query,
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@@ -141,14 +141,11 @@ class TestValidateAppToken:
mock_account = Mock()
mock_account.id = str(uuid.uuid4())
mock_ta = Mock()
mock_ta.account_id = mock_account.id
# Use side_effect to return app first, then tenant via session.get()
mock_db.session.get.side_effect = [mock_app, mock_tenant]
# Mock the tenant owner query (execute(select(...)).one_or_none())
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
# Mock the tenant owner execute result (execute(select(...)).one_or_none())
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@validate_app_token
def protected_view(app_model):
@@ -471,7 +468,7 @@ class TestValidateDatasetToken:
mock_account.current_tenant = mock_tenant
# Mock the tenant account join query (execute(select(...)).one_or_none())
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
# Mock the account lookup via session.get()
mock_db.session.get.return_value = mock_account

View File

@@ -22,18 +22,16 @@ class FakeSession:
def __init__(self, mapping: dict[str, Any] | None = None):
self._mapping: dict[str, Any] = mapping or {}
self._model_name: str | None = None
def query(self, model: type) -> FakeSession:
self._model_name = model.__name__
return self
def get(self, model: type, _ident: object) -> Any:
return self._mapping.get(model.__name__)
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
return self
def first(self) -> Any:
assert self._model_name is not None
return self._mapping.get(self._model_name)
def scalar(self, stmt: Any) -> Any:
try:
model = stmt.column_descriptions[0]["entity"]
except (AttributeError, IndexError, KeyError, TypeError):
return None
return self._mapping.get(model.__name__)
class FakeDB:

View File

@@ -36,18 +36,6 @@ class _FakeSession:
def __init__(self, mapping: dict[str, Any]):
self._mapping = mapping
self._model_name: str | None = None
def query(self, model):
self._model_name = model.__name__
return self
def where(self, *args, **kwargs):
return self
def first(self):
assert self._model_name is not None
return self._mapping.get(self._model_name)
def get(self, model, ident):
return self._mapping.get(model.__name__)

View File

@@ -34,7 +34,6 @@ def _patch_wraps():
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
patch("controllers.web.login.dify_config", web_dify),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

View File

@@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool
@@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock ConversationVariable.from_variable to return mock objects
@@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool

View File

@@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker):
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = workflow
session.get.return_value = workflow
mocker.patch.object(module.db, "session", session)
queue_manager = MagicMock()

View File

@@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker):
app_generate_entity.single_iteration_run = None
app_generate_entity.single_loop_run = None
query = MagicMock()
query.where.return_value.first.return_value = None
session = MagicMock()
session.query.return_value = query
session.get.side_effect = [None, None]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(
@@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker):
app_generate_entity = _build_app_generate_entity()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
session = MagicMock()
session.query.return_value = query_pipeline
session.get.side_effect = [None, pipeline]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(

View File

@@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime:
"last_edited_time": "2024-11-27T18:00:00.000Z",
}
mock_request.return_value = mock_response
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
extractor_page.update_last_edited_time(mock_document_model)
@@ -863,9 +860,6 @@ class TestNotionExtractorIntegration:
}
mock_request.side_effect = [last_edited_response, block_response]
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
documents = extractor.extract()
@@ -919,10 +913,6 @@ class TestNotionExtractorIntegration:
}
mock_post.return_value = database_response
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
documents = extractor.extract()

View File

@@ -40,11 +40,11 @@ class TestObfuscatedToken:
class TestEncryptToken:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_successful_encryption(self, mock_encrypt, mock_query):
def test_successful_encryption(self, mock_encrypt, mock_get):
"""Test successful token encryption"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_data"
result = encrypt_token("tenant-123", "test_token")
@@ -53,9 +53,9 @@ class TestEncryptToken:
mock_encrypt.assert_called_with("test_token", "mock_public_key")
@patch("extensions.ext_database.db.session.get")
def test_tenant_not_found(self, mock_query):
def test_tenant_not_found(self, mock_get):
"""Test error when tenant doesn't exist"""
mock_query.return_value = None
mock_get.return_value = None
with pytest.raises(ValueError) as exc_info:
encrypt_token("invalid-tenant", "test_token")
@@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
@patch("libs.rsa.decrypt")
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get):
"""Test that encryption and decryption are consistent"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
# Setup mock encryption/decryption
original_token = "test_token_123"
@@ -148,12 +148,12 @@ class TestSecurity:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
def test_cross_tenant_isolation(self, mock_encrypt, mock_get):
"""Ensure tokens encrypted for one tenant cannot be used by another"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "tenant1_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_for_tenant1"
# Encrypt token for tenant1
@@ -183,10 +183,10 @@ class TestSecurity:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_encryption_randomness(self, mock_encrypt, mock_query):
def test_encryption_randomness(self, mock_encrypt, mock_get):
"""Ensure same plaintext produces different ciphertext"""
mock_tenant = MagicMock(encrypt_public_key="key")
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
# Different outputs for same input
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
@@ -207,11 +207,11 @@ class TestEdgeCases:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get):
"""Test encryption of empty token"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_empty"
result = encrypt_token("tenant-123", "")
@@ -221,11 +221,11 @@ class TestEdgeCases:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get):
"""Test tokens containing special/unicode characters"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_special"
# Test various special characters
@@ -244,11 +244,11 @@ class TestEdgeCases:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get):
"""Test behavior when token exceeds RSA encryption limits"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
# RSA 2048-bit can only encrypt ~245 bytes
# The actual limit depends on padding scheme

View File

@@ -491,7 +491,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}
@@ -517,7 +517,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
# Cause exception in node_type logic
workflow.graph_dict = {"graph": {"nodes": []}}
@@ -544,7 +544,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
@@ -632,7 +632,7 @@ class TestLLMGenerator:
instance.invoke_llm.return_value = mock_response
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}

View File

@@ -29,15 +29,6 @@ class _Field:
return ("in", self._name, tuple(values))
class _FakeQuery:
def __init__(self):
self.where_calls: list[tuple] = []
def where(self, *conditions):
self.where_calls.append(conditions)
return self
class _FakeExecuteResult:
def __init__(self, segments: list[SimpleNamespace]):
self._segments = segments

View File

@@ -109,17 +109,6 @@ class _FakeExecuteResult:
return _FakeExecuteScalarResult(self._data)
class _FakeSummaryQuery:
def __init__(self, summaries: list) -> None:
self._summaries = summaries
def filter(self, *args, **kwargs):
return self
def all(self) -> list:
return self._summaries
class _FakeScalarsResult:
def __init__(self, data: list) -> None:
self._data = data

View File

@@ -372,19 +372,11 @@ def test_vector_delegation_methods(vector_factory_module):
def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
class _Field:
def __eq__(self, value):
return value
upload_query = MagicMock()
upload_query.where.return_value = upload_query
vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
vector._embeddings = MagicMock()
vector._vector_processor = MagicMock()
mock_session = SimpleNamespace(get=lambda _model, _id: None)
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))
assert vector.search_by_file("file-1") == []

View File

@@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk:
mock_dependencies["redis"].get.return_value = None
# Mock database query for segment updates
mock_query = MagicMock()
mock_dependencies["db"].session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.update.return_value = None
# Mock database update for segment status
mock_dependencies["db"].session.execute.return_value = None
# Create a proper context manager mock
mock_context = MagicMock()

View File

@@ -2417,12 +2417,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
mock_datasets = MagicMock()
mock_datasets.all.return_value = [mock_dataset_from_db]
mock_documents = MagicMock()
mock_documents.all.return_value = [mock_document]
mock_session.scalars.side_effect = [mock_datasets, mock_documents]
# Act
result = dataset_retrieval.knowledge_retrieval(request)

View File

@@ -451,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
mock_datasets = MagicMock()
mock_datasets.all.return_value = [mock_dataset_from_db]
mock_documents = MagicMock()
mock_documents.all.return_value = [mock_document]
mock_session.scalars.side_effect = [mock_datasets, mock_documents]
# Act
result = dataset_retrieval.knowledge_retrieval(request)

File diff suppressed because it is too large Load Diff

View File

@@ -1,925 +0,0 @@
"""
Extensive unit tests for ``ExternalDatasetService``.
This module focuses on the *external dataset service* surface area, which is responsible
for integrating with **external knowledge APIs** and wiring them into Dify datasets.
The goal of this test suite is twofold:
- Provide **high-confidence** regression coverage for all public helpers on
``ExternalDatasetService``.
- Serve as **executable documentation** for how external API integration is expected
to behave in different scenarios (happy paths, validation failures, and error codes).
The file intentionally contains **rich comments and generous spacing** in order to make
each scenario easy to scan during reviews.
"""
from __future__ import annotations

import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch

import httpx
import pytest

from constants import HIDDEN_VALUE
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
from services.entities.external_knowledge_entities.external_knowledge_entities import (
    Authorization,
    AuthorizationConfig,
    ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
from services.external_knowledge_service import ExternalDatasetService
class ExternalDatasetTestDataFactory:
    """
    Factory helpers for building *lightweight* fixtures for external knowledge tests.

    These helpers are intentionally small and explicit:

    - They avoid pulling in unnecessary fixtures.
    - They reflect the minimal contract that the service under test cares about.
    """

    @staticmethod
    def create_external_api(
        api_id: str = "api-123",
        tenant_id: str = "tenant-1",
        name: str = "Test API",
        description: str = "Description",
        settings: dict[str, Any] | None = None,
    ) -> ExternalKnowledgeApis:
        """
        Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.

        Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
        exercise ``settings_dict`` and other convenience properties if needed.
        """
        # Bug fix: the model stores ``settings`` as a JSON string, so serialize the
        # mapping with ``json.dumps`` instead of the previous nonsensical
        # ``cast(str, pytest.approx)`` placeholder, which assigned a function object.
        instance = ExternalKnowledgeApis(
            tenant_id=tenant_id,
            name=name,
            description=description,
            settings=None if settings is None else json.dumps(settings),
        )
        # Overwrite the generated id for determinism in assertions.
        instance.id = api_id
        return instance

    @staticmethod
    def create_dataset(
        dataset_id: str = "ds-1",
        tenant_id: str = "tenant-1",
        name: str = "External Dataset",
        provider: str = "external",
    ) -> Dataset:
        """
        Build a small ``Dataset`` instance representing an external dataset.
        """
        dataset = Dataset(
            tenant_id=tenant_id,
            name=name,
            description="",
            provider=provider,
            created_by="user-1",
        )
        dataset.id = dataset_id
        return dataset

    @staticmethod
    def create_external_binding(
        tenant_id: str = "tenant-1",
        dataset_id: str = "ds-1",
        api_id: str = "api-1",
        external_knowledge_id: str = "knowledge-1",
    ) -> ExternalKnowledgeBindings:
        """
        Small helper for a binding between a dataset and an external knowledge API.
        """
        return ExternalKnowledgeBindings(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            external_knowledge_api_id=api_id,
            external_knowledge_id=external_knowledge_id,
            created_by="user-1",
        )
# ---------------------------------------------------------------------------
# get_external_knowledge_apis
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceGetExternalKnowledgeApis:
    """
    Tests for ``ExternalDatasetService.get_external_knowledge_apis``.

    Coverage targets:

    - the pagination plumbing that goes through ``db.paginate``
    - the optional search-keyword code path
    """

    @pytest.fixture
    def mock_db_paginate(self):
        """
        Replace ``db.paginate`` (and the ``select`` helper) so the real database
        layer is never touched by these tests.
        """
        with (
            patch("services.external_knowledge_service.db.paginate", autospec=True) as paginate_mock,
            patch("services.external_knowledge_service.select", autospec=True),
        ):
            yield paginate_mock

    def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
        """
        The service should surface ``items`` and ``total`` straight from the
        pagination object returned by ``db.paginate``.
        """
        # Arrange
        fake_records = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
        mock_db_paginate.return_value = SimpleNamespace(items=fake_records, total=42)

        # Act
        items, total = ExternalDatasetService.get_external_knowledge_apis(1, 20, "tenant-1")

        # Assert
        assert items is fake_records
        assert total == 42
        mock_db_paginate.assert_called_once()
        kwargs = mock_db_paginate.call_args.kwargs
        assert kwargs["page"] == 1
        assert kwargs["per_page"] == 20
        assert kwargs["max_per_page"] == 100
        assert kwargs["error_out"] is False

    def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
        """
        Supplying a search keyword should adjust the query without breaking the
        paginate call; an empty result set is passed through untouched.
        """
        # Arrange
        mock_db_paginate.return_value = SimpleNamespace(items=[], total=0)

        # Act
        items, total = ExternalDatasetService.get_external_knowledge_apis(2, 10, "tenant-1", search="foo")

        # Assert
        assert items == []
        assert total == 0
        mock_db_paginate.assert_called_once()
# ---------------------------------------------------------------------------
# validate_api_list
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceValidateApiList:
    """
    Lightweight validation tests for ``validate_api_list``.
    """

    def test_validate_api_list_success(self):
        """
        A minimal config carrying both endpoint and api_key must be accepted
        without raising.
        """
        ExternalDatasetService.validate_api_list({"endpoint": "https://example.com", "api_key": "secret"})

    @pytest.mark.parametrize(
        ("config", "expected_message"),
        [
            ({}, "api list is empty"),
            ({"api_key": "k"}, "endpoint is required"),
            ({"endpoint": "https://example.com"}, "api_key is required"),
        ],
    )
    def test_validate_api_list_failures(self, config: dict[str, Any], expected_message: str):
        """
        Each invalid config must raise ``ValueError`` carrying a clear message.
        """
        with pytest.raises(ValueError, match=expected_message):
            ExternalDatasetService.validate_api_list(config)
# ---------------------------------------------------------------------------
# create_external_knowledge_api & get/update/delete
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceCrudExternalKnowledgeApi:
    """
    CRUD tests for external knowledge API templates.
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Patch ``db.session`` for every CRUD test in this class.
        """
        with patch("services.external_knowledge_service.db.session", autospec=True) as session_mock:
            yield session_mock

    def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        A new record is persisted when valid settings are supplied.
        """
        payload = {
            "name": "API",
            "description": "desc",
            "settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
        }

        # The remote endpoint must not actually be contacted, so the validator is patched out.
        with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
            created = ExternalDatasetService.create_external_knowledge_api("tenant-1", "user-1", payload)

        assert isinstance(created, ExternalKnowledgeApis)
        mock_check.assert_called_once_with(payload["settings"])
        mock_db_session.add.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
        """
        Omitting ``settings`` must raise ``ValueError`` and leave the session untouched.
        """
        with pytest.raises(ValueError, match="settings is required"):
            ExternalDatasetService.create_external_knowledge_api(
                "tenant-1", "user-1", {"name": "API", "description": "desc"}
            )

        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
        """
        ``get_external_knowledge_api`` returns the first matching record as-is.
        """
        stored = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.scalar.return_value = stored

        assert ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id") is stored

    def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        An absent record translates into ``ValueError``.
        """
        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")

    def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
        """
        When the UI sends the hidden-value placeholder, the update must keep
        the previously stored API key.
        """
        stored_api = Mock(spec=ExternalKnowledgeApis)
        stored_api.settings_dict = {"api_key": "stored-key"}
        stored_api.settings = '{"api_key":"stored-key"}'
        mock_db_session.scalar.return_value = stored_api

        payload = {
            "name": "New Name",
            "description": "New Desc",
            "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
        }

        updated = ExternalDatasetService.update_external_knowledge_api("tenant-1", "user-1", "api-1", payload)

        assert updated is stored_api
        # The placeholder must have been swapped for the stored key.
        assert payload["settings"]["api_key"] == "stored-key"
        mock_db_session.commit.assert_called_once()

    def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Updating a nonexistent API template must raise ``ValueError``.
        """
        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.update_external_knowledge_api(
                tenant_id="tenant-1",
                user_id="user-1",
                external_knowledge_api_id="missing-id",
                args={"name": "n", "description": "d", "settings": {}},
            )

    def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        A found template is deleted and the session committed.
        """
        doomed = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.scalar.return_value = doomed

        ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")

        mock_db_session.delete.assert_called_once_with(doomed)
        mock_db_session.commit.assert_called_once()

    def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Deleting a missing template must raise ``ValueError``.
        """
        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
# ---------------------------------------------------------------------------
# external_knowledge_api_use_check & binding lookups
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceUsageAndBindings:
    """
    Covers the usage-count check and external knowledge binding lookups.
    """

    @pytest.fixture
    def mock_db_session(self):
        # Replace the module-level session so no real database is touched.
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
        """
        A non-zero binding count means the API is in use.
        """
        mock_db_session.scalar.return_value = 3
        is_used, binding_count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
        assert is_used is True
        assert binding_count == 3
        # The statement handed to scalar must filter on tenant_id.
        assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])

    def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
        """
        A count of zero bindings yields ``(False, 0)``.
        """
        mock_db_session.scalar.return_value = 0
        is_used, binding_count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
        assert is_used is False
        assert binding_count == 0

    def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
        """
        An existing binding record is returned as-is.
        """
        expected_binding = Mock(spec=ExternalKnowledgeBindings)
        mock_db_session.scalar.return_value = expected_binding
        found = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
        assert found is expected_binding

    def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
        """
        A missing binding must raise ``ValueError``.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
# ---------------------------------------------------------------------------
# document_create_args_validate
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceDocumentCreateArgsValidate:
    """
    Tests for ``document_create_args_validate``.

    The service loads the API template, parses its ``settings`` JSON and
    checks that every required ``document_process_setting`` parameter is
    present in the supplied mapping.
    """

    @pytest.fixture
    def mock_db_session(self):
        # Patch the module-level session so no real database is required.
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
        """
        All required custom parameters present validation should pass.
        """
        external_api = Mock(spec=ExternalKnowledgeApis)
        # Raw string; the service itself calls json.loads on it
        external_api.settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        mock_db_session.scalar.return_value = external_api
        process_parameter = {"foo": "value", "bar": "optional"}
        # Act & Assert no exception
        ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
        # The template must have been looked up exactly once.
        # (Replaces a previous tautological check that asserted the settings
        # string was contained in itself.)
        mock_db_session.scalar.assert_called_once()

    def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
        """
        When the referenced API template is missing, a ``ValueError`` is raised.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})

    def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
        """
        Required document process parameters must be supplied.
        """
        external_api = Mock(spec=ExternalKnowledgeApis)
        external_api.settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        mock_db_session.scalar.return_value = external_api
        process_parameter = {"bar": "present"}  # missing "foo"
        with pytest.raises(ValueError, match="foo is required"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
# ---------------------------------------------------------------------------
# process_external_api
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceProcessExternalApi:
    """
    Exercises HTTP request assembly and verb-to-function dispatch.
    """

    def test_process_external_api_valid_method_post(self):
        """
        A supported HTTP verb dispatches to the matching ``ssrf_proxy`` helper.
        """
        api_setting = ExternalKnowledgeApiSetting(
            url="https://example.com/path",
            request_method="POST",
            headers={"X-Test": "1"},
            params={"foo": "bar"},
        )
        stub_response = httpx.Response(200)
        with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
            mock_post.return_value = stub_response
            returned = ExternalDatasetService.process_external_api(api_setting, files=None)
        assert returned is stub_response
        mock_post.assert_called_once()
        sent_kwargs = mock_post.call_args.kwargs
        assert sent_kwargs["url"] == api_setting.url
        assert sent_kwargs["headers"] == api_setting.headers
        assert sent_kwargs["follow_redirects"] is True
        assert "data" in sent_kwargs

    def test_process_external_api_invalid_method_raises(self):
        """
        Unknown HTTP verbs are rejected with ``InvalidHttpMethodError``.
        """
        from graphon.nodes.http_request.exc import InvalidHttpMethodError

        api_setting = ExternalKnowledgeApiSetting(
            url="https://example.com",
            request_method="INVALID",
            headers=None,
            params={},
        )
        with pytest.raises(InvalidHttpMethodError):
            ExternalDatasetService.process_external_api(api_setting, files=None)
# ---------------------------------------------------------------------------
# assembling_headers
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceAssemblingHeaders:
    """
    Header assembly for each supported authentication flavour.
    """

    def test_assembling_headers_bearer_token(self):
        """
        Bearer auth defaults to ``Authorization: Bearer <key>``.
        """
        authorization = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
        )
        assembled = ExternalDatasetService.assembling_headers(authorization)
        assert assembled["Authorization"] == "Bearer secret"

    def test_assembling_headers_basic_token_with_custom_header(self):
        """
        Basic auth honours the configured header name and keeps existing headers.
        """
        authorization = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
        )
        assembled = ExternalDatasetService.assembling_headers(authorization, headers={"Existing": "1"})
        assert assembled["Existing"] == "1"
        assert assembled["X-Auth"] == "Basic abc123"

    def test_assembling_headers_custom_type(self):
        """
        Custom auth injects the raw API key under the configured header.
        """
        authorization = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
        )
        assembled = ExternalDatasetService.assembling_headers(authorization, headers=None)
        assert assembled["X-API-KEY"] == "raw-key"

    def test_assembling_headers_missing_config_raises(self):
        """
        An ``api-key`` authorization without a config object is rejected.
        """
        authorization = Authorization(type="api-key", config=None)
        with pytest.raises(ValueError, match="authorization config is required"):
            ExternalDatasetService.assembling_headers(authorization)

    def test_assembling_headers_missing_api_key_raises(self):
        """
        The API key itself is mandatory when type is ``api-key``.
        """
        authorization = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
        )
        with pytest.raises(ValueError, match="api_key is required"):
            ExternalDatasetService.assembling_headers(authorization)

    def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
        """
        ``no-auth`` returns an untouched copy of the supplied headers.
        """
        authorization = Authorization(type="no-auth", config=None)
        base_headers = {"X": "1"}
        assembled = ExternalDatasetService.assembling_headers(authorization, headers=base_headers)
        # A copy is returned, original is not mutated.
        assert assembled == base_headers
        assert assembled is not base_headers
# ---------------------------------------------------------------------------
# get_external_knowledge_api_settings
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
    """
    Shape test: the settings dict maps field-for-field onto
    ``ExternalKnowledgeApiSetting``.
    """

    def test_get_external_knowledge_api_settings(self):
        raw_settings: dict[str, Any] = {
            "url": "https://example.com/retrieval",
            "request_method": "post",
            "headers": {"Content-Type": "application/json"},
            "params": {"foo": "bar"},
        }
        setting = ExternalDatasetService.get_external_knowledge_api_settings(raw_settings)
        assert isinstance(setting, ExternalKnowledgeApiSetting)
        assert setting.url == raw_settings["url"]
        assert setting.request_method == raw_settings["request_method"]
        assert setting.headers == raw_settings["headers"]
        assert setting.params == raw_settings["params"]
# ---------------------------------------------------------------------------
# create_external_dataset
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceCreateExternalDataset:
    """
    Tests around creating the external dataset and its binding row.

    NOTE(review): the ``scalar.side_effect`` lists below mirror the exact
    order in which the service issues its lookups (duplicate-name check
    first, then the API template) — keep them in sync with the service.
    """
    @pytest.fixture
    def mock_db_session(self):
        # Patch the module-level session so no real database is required.
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session
    def test_create_external_dataset_success(self, mock_db_session: MagicMock):
        """
        A brand new dataset name with valid external knowledge references
        should create both the dataset and its binding.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {
            "name": "My Dataset",
            "description": "desc",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
            "external_retrieval_model": {"top_k": 3},
        }
        # No existing dataset with same name.
        mock_db_session.scalar.side_effect = [
            None,  # duplicate-name check
            Mock(spec=ExternalKnowledgeApis),  # external knowledge api
        ]
        dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
        assert isinstance(dataset, Dataset)
        assert dataset.provider == "external"
        assert dataset.retrieval_model == args["external_retrieval_model"]
        assert mock_db_session.add.call_count >= 2  # dataset + binding
        mock_db_session.flush.assert_called_once()
        mock_db_session.commit.assert_called_once()
    def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
        """
        When a dataset with the same name already exists,
        ``DatasetNameDuplicateError`` is raised.
        """
        existing_dataset = Mock(spec=Dataset)
        mock_db_session.scalar.return_value = existing_dataset
        args = {
            "name": "Existing",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
        }
        with pytest.raises(DatasetNameDuplicateError):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
        # Nothing may be persisted when the name check fails.
        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()
    def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
        """
        # First call: duplicate name check not found.
        mock_db_session.scalar.side_effect = [
            None,
            None,  # external knowledge api lookup
        ]
        args = {
            "name": "Dataset",
            "external_knowledge_api_id": "missing",
            "external_knowledge_id": "knowledge-1",
        }
        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
    def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
        """
        ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
        """
        # duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
        mock_db_session.scalar.side_effect = [
            None,
            Mock(spec=ExternalKnowledgeApis),
            None,
            Mock(spec=ExternalKnowledgeApis),
        ]
        args_missing_knowledge_id = {
            "name": "Dataset",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": None,
        }
        with pytest.raises(ValueError, match="external_knowledge_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)
        args_missing_api_id = {
            "name": "Dataset",
            "external_knowledge_api_id": None,
            "external_knowledge_id": "k-1",
        }
        with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)
# ---------------------------------------------------------------------------
# fetch_external_knowledge_retrieval
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
    """
    Tests for ``fetch_external_knowledge_retrieval`` which orchestrates
    external retrieval requests and normalises the response payload.
    """
    @pytest.fixture
    def mock_db_session(self):
        # Patch the module-level session so no real database is required.
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session
    def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
        """
        With a valid binding and API template, records from the external
        service should be returned when the HTTP response is 200.
        """
        tenant_id = "tenant-1"
        dataset_id = "ds-1"
        query = "test query"
        external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}
        binding = ExternalDatasetTestDataFactory.create_external_binding(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            api_id="api-1",
            external_knowledge_id="knowledge-1",
        )
        api = Mock(spec=ExternalKnowledgeApis)
        api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
        # First query: binding; second query: api.
        mock_db_session.scalar.side_effect = [
            binding,
            api,
        ]
        fake_records = [{"content": "doc", "score": 0.9}]
        fake_response = Mock(spec=httpx.Response)
        fake_response.status_code = 200
        fake_response.json.return_value = {"records": fake_records}
        # Minimal stand-in exposing only the model_dump() the service calls.
        metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
        with patch.object(
            ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True
        ) as mock_process:
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                query=query,
                external_retrieval_parameters=external_retrieval_parameters,
                metadata_condition=metadata_condition,
            )
        assert result == fake_records
        mock_process.assert_called_once()
        # First positional arg is the assembled setting aimed at /retrieval.
        setting_arg = mock_process.call_args.args[0]
        assert isinstance(setting_arg, ExternalKnowledgeApiSetting)
        assert setting_arg.url.endswith("/retrieval")
    def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
        """
        Missing binding should raise ``ValueError``.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="missing",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )
    def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        When the API template is missing or has no settings, a ``ValueError`` is raised.
        """
        binding = ExternalDatasetTestDataFactory.create_external_binding()
        # Binding lookup succeeds, template lookup returns nothing.
        mock_db_session.scalar.side_effect = [
            binding,
            None,
        ]
        with pytest.raises(ValueError, match="external api template not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )
    def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
        """
        Non-200 responses should be treated as an empty result set.
        """
        binding = ExternalDatasetTestDataFactory.create_external_binding()
        api = Mock(spec=ExternalKnowledgeApis)
        api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
        mock_db_session.scalar.side_effect = [
            binding,
            api,
        ]
        # Simulate a server-side failure from the external service.
        fake_response = Mock(spec=httpx.Response)
        fake_response.status_code = 500
        fake_response.json.return_value = {}
        with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True):
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )
        assert result == []

View File

@@ -374,24 +374,14 @@ def test_publish_workflow_success(mocker, rag_pipeline_service) -> None:
mock_db = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db", mock_db)
mock_dataset_service_class = mocker.patch("services.dataset_service.DatasetService")
mock_dataset_service = mock_dataset_service_class.return_value
# 6. Mock session and its scalar/query methods
# 6. Mock session and dataset lookup
mock_session = mocker.Mock()
mock_session.scalar.return_value = draft_wf
# Mock dataset update query (needed even if service is mocked, as rag_pipeline fetches it first)
dataset = mocker.Mock()
dataset.retrieval_model_dict = {}
dataset_query = mocker.Mock()
dataset_query.where.return_value.first.return_value = dataset
# Mock node execution copy
node_exec_query = mocker.Mock()
node_exec_query.where.return_value.all.return_value = []
# Mocked session query side effects
mock_session.query.side_effect = [node_exec_query, dataset_query]
pipeline.retrieve_dataset.return_value = dataset
# 7. Run test
result = rag_pipeline_service.publish_workflow(session=mock_session, pipeline=pipeline, account=account)
@@ -1524,7 +1514,6 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
)
document = SimpleNamespace(indexing_status="waiting", error=None)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -1595,7 +1584,6 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
@@ -1604,7 +1592,6 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
with pytest.raises(ValueError, match="Pipeline not found"):
@@ -1644,7 +1631,6 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -1871,7 +1857,6 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
with pytest.raises(ValueError, match="Workflow not found"):
@@ -1910,7 +1895,6 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
@@ -1923,7 +1907,6 @@ def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_
def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
@@ -1940,7 +1923,6 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
workflow = SimpleNamespace(
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@@ -2103,7 +2085,6 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]},
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
mocker.patch(
@@ -2143,7 +2124,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
{"variable": "v3", "belong_to_node_id": "shared"},
],
)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
mocker.patch(
@@ -2161,7 +2141,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
result = rag_pipeline_service.get_pipeline("t1", "d1")

File diff suppressed because it is too large Load Diff

View File

@@ -1,59 +0,0 @@
from unittest.mock import MagicMock
class ServiceDbTestHelper:
"""
Helper class for service database query tests.
"""
@staticmethod
def setup_db_query_filter_by_mock(mock_db, query_results):
"""
Smart database query mock that responds based on model type and query parameters.
Args:
mock_db: Mock database session
query_results: Dict mapping (model_name, filter_key, filter_value) to return value
Example: {('Account', 'email', 'test@example.com'): mock_account}
"""
def query_side_effect(model):
mock_query = MagicMock()
def filter_by_side_effect(**kwargs):
mock_filter_result = MagicMock()
def first_side_effect():
# Find matching result based on model and filter parameters
for (model_name, filter_key, filter_value), result in query_results.items():
if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
return result
return None
mock_filter_result.first.side_effect = first_side_effect
# Handle order_by calls for complex queries
def order_by_side_effect(*args, **kwargs):
mock_order_result = MagicMock()
def order_first_side_effect():
# Look for order_by results in the same query_results dict
for (model_name, filter_key, filter_value), result in query_results.items():
if (
model.__name__ == model_name
and filter_key == "order_by"
and filter_value == "first_available"
):
return result
return None
mock_order_result.first.side_effect = order_first_side_effect
return mock_order_result
mock_filter_result.order_by.side_effect = order_by_side_effect
return mock_filter_result
mock_query.filter_by.side_effect = filter_by_side_effect
return mock_query
mock_db.session.query.side_effect = query_side_effect

View File

@@ -14,7 +14,6 @@ from services.errors.account import (
AccountRegisterError,
CurrentPasswordIncorrectError,
)
from tests.unit_tests.services.services_test_help import ServiceDbTestHelper
class TestAccountAssociatedDataFactory:
@@ -149,7 +148,6 @@ class TestAccountService:
# Setup basic session methods
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.query = MagicMock()
yield mock_db
@@ -1572,15 +1570,9 @@ class TestRegisterService:
account_id="existing-user-456", email="existing@example.com", status="active"
)
# Mock database queries
query_results = {
(
"TenantAccountJoin",
"tenant_id",
"tenant-456",
): TestAccountAssociatedDataFactory.create_tenant_join_mock(),
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies[
"db"
].session.scalar.return_value = TestAccountAssociatedDataFactory.create_tenant_join_mock()
# Mock TenantService methods
with (

View File

@@ -714,7 +714,6 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
):
segments_query = MagicMock()
# execute().all() for segments_info (multi-column)
execute_result = MagicMock()
execute_result.all.return_value = [

View File

@@ -36,9 +36,7 @@ class TestDatasourceProviderService:
@pytest.fixture
def mock_db_session(self):
"""
Robust, chainable query mock.
q returns itself for .filter_by(), .order_by(), .where() so any
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
Mock session with scalar/scalars defaults for current SQLAlchemy access paths.
"""
with (
patch("services.datasource_provider_service.Session") as mock_cls,
@@ -46,20 +44,6 @@ class TestDatasourceProviderService:
):
sess = MagicMock(spec=Session)
q = MagicMock()
sess.query.return_value = q
# Self-returning chain — any method called on q returns q
q.filter_by.return_value = q
q.order_by.return_value = q
q.where.return_value = q
# Default terminal values (tests override per-case)
q.first.return_value = None
q.all.return_value = []
q.count.return_value = 0
q.delete.return_value = 1
# Default values for select()-style calls (tests override per-case)
sess.scalar.return_value = None
sess.scalars.return_value.all.return_value = []

View File

@@ -17,23 +17,6 @@ from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)

View File

@@ -1649,8 +1649,6 @@ class TestWorkflowServiceCredentialValidation:
"""Missing BuiltinToolProvider → plugin requires no credentials → no error."""
# Arrange
with patch("services.workflow_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
# Act + Assert (should NOT raise)
service._check_default_tool_credential("tenant-1", "some-provider")
@@ -1662,10 +1660,6 @@ class TestWorkflowServiceCredentialValidation:
patch("services.workflow_service.db") as mock_db,
patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")),
):
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_provider
)
# Act + Assert
with pytest.raises(ValueError, match="Failed to validate default credential"):
service._check_default_tool_credential("tenant-1", "some-provider")

File diff suppressed because it is too large Load Diff

View File

@@ -89,9 +89,6 @@ def mock_db_session():
session = MagicMock()
session._shared_data = {"dataset": None, "documents": []}
# Keep a pointer so repeated Document.first() calls iterate across provided docs
session._doc_first_idx = 0
def _get_entity(stmt) -> type | None:
"""Extract the mapped entity class from a SQLAlchemy select statement."""
try:
@@ -1591,18 +1588,7 @@ class TestDocumentIndexingTaskSummaryFlow:
need_summary=True,
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")]
phase1_document_query = MagicMock()
phase1_document_query.where.return_value = phase1_document_query
phase1_document_query.all.return_value = phase1_docs
summary_document_query = MagicMock()
summary_document_query.where.return_value = summary_document_query
summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status]
session1 = MagicMock()
session2 = MagicMock()
@@ -1657,18 +1643,6 @@ class TestDocumentIndexingTaskSummaryFlow:
need_summary=True,
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_query = MagicMock()
phase1_query.where.return_value = phase1_query
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
summary_query = MagicMock()
summary_query.where.return_value = summary_query
summary_query.all.return_value = [doc_eligible]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()