mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 01:00:51 -04:00
chore: fix use select style api in orm (#35531)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@@ -225,8 +225,10 @@ class TestSpanBuilder:
|
||||
span = builder.build_span(span_data)
|
||||
assert isinstance(span, ReadableSpan)
|
||||
assert span.name == "test-span"
|
||||
assert span.context is not None
|
||||
assert span.context.trace_id == 123
|
||||
assert span.context.span_id == 456
|
||||
assert span.parent is not None
|
||||
assert span.parent.span_id == 789
|
||||
assert span.resource == resource
|
||||
assert span.attributes == {"attr1": "val1"}
|
||||
|
||||
@@ -64,12 +64,13 @@ class TestSpanData:
|
||||
|
||||
def test_span_data_missing_required_fields(self):
|
||||
with pytest.raises(ValidationError):
|
||||
SpanData(
|
||||
trace_id=123,
|
||||
# span_id missing
|
||||
name="test_span",
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
SpanData.model_validate(
|
||||
{
|
||||
"trace_id": 123,
|
||||
"name": "test_span",
|
||||
"start_time": 1000,
|
||||
"end_time": 2000,
|
||||
}
|
||||
)
|
||||
|
||||
def test_span_data_arbitrary_types_allowed(self):
|
||||
|
||||
@@ -2,12 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
|
||||
import pytest
|
||||
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
@@ -44,7 +46,7 @@ class RecordingTraceClient:
|
||||
self.endpoint = endpoint
|
||||
self.added_spans: list[object] = []
|
||||
|
||||
def add_span(self, span) -> None:
|
||||
def add_span(self, span: object) -> None:
|
||||
self.added_spans.append(span)
|
||||
|
||||
def api_check(self) -> bool:
|
||||
@@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags.SAMPLED,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
return Link(context)
|
||||
|
||||
|
||||
def _make_trace_metadata(
|
||||
trace_id: int = 1,
|
||||
workflow_span_id: int = 2,
|
||||
session_id: str = "s",
|
||||
user_id: str = "u",
|
||||
links: list[Link] | None = None,
|
||||
) -> TraceMetadata:
|
||||
return TraceMetadata(
|
||||
trace_id=trace_id,
|
||||
workflow_span_id=workflow_span_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
links=[] if links is None else links,
|
||||
)
|
||||
|
||||
|
||||
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
|
||||
return cast(RecordingTraceClient, trace_instance.trace_client)
|
||||
|
||||
|
||||
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
|
||||
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
|
||||
|
||||
|
||||
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
|
||||
defaults = {
|
||||
"workflow_id": "workflow-id",
|
||||
@@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
add_workflow_span.assert_called_once()
|
||||
passed_trace_metadata = add_workflow_span.call_args.args[1]
|
||||
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
|
||||
assert passed_trace_metadata.trace_id == 111
|
||||
assert passed_trace_metadata.workflow_span_id == 222
|
||||
assert passed_trace_metadata.session_id == "c"
|
||||
assert passed_trace_metadata.user_id == "u"
|
||||
assert passed_trace_metadata.links == []
|
||||
|
||||
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
|
||||
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
|
||||
|
||||
|
||||
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_message_trace_info(message_data=None)
|
||||
trace_instance.message_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
assert _recording_trace_client(trace_instance).added_spans == []
|
||||
|
||||
|
||||
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
@@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
|
||||
)
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 2
|
||||
message_span, llm_span = trace_instance.trace_client.added_spans
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 2
|
||||
message_span, llm_span = spans
|
||||
|
||||
assert message_span.name == "message"
|
||||
assert message_span.trace_id == 10
|
||||
@@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
|
||||
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
assert _recording_trace_client(trace_instance).added_spans == []
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
@@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
|
||||
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
|
||||
|
||||
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "dataset_retrieval"
|
||||
assert span.attributes[RETRIEVAL_QUERY] == "query"
|
||||
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
|
||||
@@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
|
||||
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_tool_trace_info(message_data=None)
|
||||
trace_instance.tool_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
assert _recording_trace_client(trace_instance).added_spans == []
|
||||
|
||||
|
||||
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
@@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
|
||||
)
|
||||
)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "my-tool"
|
||||
assert span.status == status
|
||||
assert span.attributes[TOOL_NAME] == "my-tool"
|
||||
@@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
|
||||
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
|
||||
|
||||
@@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
|
||||
):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
|
||||
|
||||
@@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
|
||||
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
|
||||
|
||||
@@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
|
||||
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
|
||||
|
||||
@@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
|
||||
):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_execution.node_type = BuiltinNodeTypes.CODE
|
||||
@@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "title"
|
||||
@@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
|
||||
trace_metadata = _make_trace_metadata(links=[_make_link()])
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "my-tool"
|
||||
@@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
|
||||
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
|
||||
)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "retrieval"
|
||||
@@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
|
||||
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
|
||||
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "llm"
|
||||
@@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
trace_metadata = _make_trace_metadata()
|
||||
|
||||
# CASE 1: With message_id
|
||||
trace_info = _make_workflow_trace_info(
|
||||
@@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
|
||||
)
|
||||
trace_instance.add_workflow_span(trace_info, trace_metadata)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 2
|
||||
message_span = trace_instance.trace_client.added_spans[0]
|
||||
workflow_span = trace_instance.trace_client.added_spans[1]
|
||||
client = _recording_trace_client(trace_instance)
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 2
|
||||
message_span = spans[0]
|
||||
workflow_span = spans[1]
|
||||
|
||||
assert message_span.name == "message"
|
||||
assert message_span.span_kind == SpanKind.SERVER
|
||||
@@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
|
||||
assert workflow_span.span_kind == SpanKind.INTERNAL
|
||||
assert workflow_span.parent_span_id == 20
|
||||
|
||||
trace_instance.trace_client.added_spans.clear()
|
||||
client.added_spans.clear()
|
||||
|
||||
# CASE 2: Without message_id
|
||||
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
|
||||
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "workflow"
|
||||
assert span.span_kind == SpanKind.SERVER
|
||||
assert span.parent_span_id is None
|
||||
@@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
|
||||
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
spans = _recorded_span_data(trace_instance)
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "suggested_question"
|
||||
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
@@ -170,7 +172,7 @@ def test_create_common_span_attributes():
|
||||
|
||||
def test_format_retrieval_documents():
|
||||
# Not a list
|
||||
assert format_retrieval_documents("not a list") == []
|
||||
assert format_retrieval_documents(cast(list[object], "not a list")) == []
|
||||
|
||||
# Valid list
|
||||
docs = [
|
||||
@@ -211,7 +213,7 @@ def test_format_retrieval_documents():
|
||||
|
||||
def test_format_input_messages():
|
||||
# Not a dict
|
||||
assert format_input_messages(None) == serialize_json_data([])
|
||||
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
|
||||
|
||||
# No prompts
|
||||
assert format_input_messages({}) == serialize_json_data([])
|
||||
@@ -244,7 +246,7 @@ def test_format_input_messages():
|
||||
|
||||
def test_format_output_messages():
|
||||
# Not a dict
|
||||
assert format_output_messages(None) == serialize_json_data([])
|
||||
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
|
||||
|
||||
# No text
|
||||
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])
|
||||
|
||||
@@ -25,13 +25,13 @@ class TestAliyunConfig:
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig()
|
||||
AliyunConfig.model_validate({})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="test_license")
|
||||
AliyunConfig.model_validate({"license_key": "test_license"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
|
||||
|
||||
def test_app_name_validation_empty(self):
|
||||
"""Test app_name validation with empty value"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -129,7 +130,7 @@ def test_set_span_status():
|
||||
return "SilentErrorRepr"
|
||||
|
||||
span.reset_mock()
|
||||
set_span_status(span, SilentError())
|
||||
set_span_status(span, cast(Exception | str | None, SilentError()))
|
||||
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"
|
||||
|
||||
|
||||
|
||||
@@ -28,13 +28,13 @@ class TestLangfuseConfig:
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig()
|
||||
LangfuseConfig.model_validate({})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(public_key="public")
|
||||
LangfuseConfig.model_validate({"public_key": "public"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(secret_key="secret")
|
||||
LangfuseConfig.model_validate({"secret_key": "secret"})
|
||||
|
||||
def test_host_validation_empty(self):
|
||||
"""Test host validation with empty value"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
@@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
|
||||
|
||||
assert trace._get_completion_start_time(start_time, None) is None
|
||||
assert trace._get_completion_start_time(start_time, -1) is None
|
||||
assert trace._get_completion_start_time(start_time, "invalid") is None
|
||||
assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None
|
||||
|
||||
@@ -21,13 +21,13 @@ class TestLangSmithConfig:
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig()
|
||||
LangSmithConfig.model_validate({})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(api_key="key")
|
||||
LangSmithConfig.model_validate({"api_key": "key"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(project="project")
|
||||
LangSmithConfig.model_validate({"project": "project"})
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
|
||||
@@ -599,7 +599,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_instance.message_trace(_make_message_trace_info())
|
||||
mock_tracing["start"].assert_called_once()
|
||||
@@ -609,7 +608,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_info = _make_message_trace_info(error="something broke")
|
||||
trace_instance.message_trace(trace_info)
|
||||
@@ -620,7 +618,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
monkeypatch.setenv("FILES_URL", "http://files.test")
|
||||
|
||||
file_data = SimpleNamespace(url="path/to/file.png")
|
||||
@@ -638,7 +635,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
|
||||
trace_instance.message_trace(trace_info)
|
||||
@@ -651,7 +647,6 @@ class TestMessageTrace:
|
||||
|
||||
end_user = MagicMock()
|
||||
end_user.session_id = "session-xyz"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
|
||||
|
||||
trace_info = _make_message_trace_info(
|
||||
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
|
||||
@@ -664,7 +659,6 @@ class TestMessageTrace:
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
trace_info = _make_message_trace_info(
|
||||
metadata={"from_account_id": "acc-1"},
|
||||
|
||||
@@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
|
||||
@@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
|
||||
return instance
|
||||
|
||||
|
||||
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
|
||||
return cast(MagicMock, instance.add_trace)
|
||||
|
||||
|
||||
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
|
||||
return cast(MagicMock, instance.add_span)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _seed_to_uuid4
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
def test_root_span_is_created(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
assert instance.add_span.called
|
||||
assert _add_span_mock(instance).called
|
||||
|
||||
def test_root_span_id_matches_expected(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
expected = self._expected_root_span_id(trace_info)
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["id"] == expected
|
||||
|
||||
def test_root_span_has_no_parent(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["parent_span_id"] is None
|
||||
|
||||
def test_trace_name_is_workflow_trace(self):
|
||||
@@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
|
||||
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
|
||||
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
def test_root_span_name_is_workflow_trace(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
def test_root_span_has_workflow_tag(self):
|
||||
trace_info = _make_workflow_trace_info(message_id=None)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert "workflow" in root_span_kwargs["tags"]
|
||||
|
||||
def test_node_execution_spans_are_parented_to_root(self):
|
||||
@@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
instance = self._run(trace_info, node_executions=[node_exec])
|
||||
|
||||
# call_args_list[0] = root span, [1] = node execution span
|
||||
assert instance.add_span.call_count == 2
|
||||
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
|
||||
add_span = _add_span_mock(instance)
|
||||
assert add_span.call_count == 2
|
||||
node_span_kwargs = add_span.call_args_list[1][0][0]
|
||||
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
|
||||
|
||||
def test_node_span_not_parented_to_workflow_app_log_id(self):
|
||||
@@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
|
||||
instance = self._run(trace_info, node_executions=[node_exec])
|
||||
|
||||
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
|
||||
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
|
||||
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
|
||||
assert node_span_kwargs["parent_span_id"] != old_parent_id
|
||||
|
||||
def test_root_span_id_differs_from_trace_id(self):
|
||||
@@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
|
||||
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
|
||||
instance = self._run(trace_info)
|
||||
|
||||
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
|
||||
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
|
||||
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
|
||||
|
||||
def test_root_span_uses_workflow_run_id_directly(self):
|
||||
@@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
|
||||
instance = self._run(trace_info)
|
||||
|
||||
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
|
||||
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
|
||||
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
|
||||
assert root_span_kwargs["id"] == expected_root_span_id
|
||||
|
||||
def test_root_span_id_differs_from_no_message_id_case(self):
|
||||
@@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
|
||||
|
||||
instance = self._run(trace_info, node_executions=[node_exec])
|
||||
|
||||
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
|
||||
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
|
||||
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, TypedDict, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
|
||||
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
|
||||
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
|
||||
|
||||
metric_reader_instances: list[DummyMetricReader] = []
|
||||
meter_provider_instances: list[DummyMeterProvider] = []
|
||||
@@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class PatchedCoreComponents(TypedDict):
|
||||
span_exporter: MagicMock
|
||||
span_processor: MagicMock
|
||||
tracer: MagicMock
|
||||
span: MagicMock
|
||||
tracer_provider: MagicMock
|
||||
logger: MagicMock
|
||||
trace_api: Any
|
||||
|
||||
|
||||
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Drop fake metric modules into sys.modules so the client imports resolve."""
|
||||
|
||||
@@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
|
||||
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
|
||||
span_exporter = MagicMock(name="span_exporter")
|
||||
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
|
||||
|
||||
@@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
|
||||
return SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
|
||||
|
||||
def _build_client() -> TencentTraceClient:
|
||||
return TencentTraceClient(
|
||||
service_name="service",
|
||||
@@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
|
||||
|
||||
|
||||
def test_resolve_grpc_target_handles_errors() -> None:
|
||||
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
|
||||
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
|
||||
client.record_trace_duration(0.5)
|
||||
|
||||
|
||||
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
|
||||
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
|
||||
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
|
||||
@@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
|
||||
logger.debug.assert_called()
|
||||
|
||||
|
||||
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
ctx = _make_span_context(span_id=2)
|
||||
span.get_span_context.return_value = ctx
|
||||
|
||||
data = SpanData(
|
||||
trace_id=1,
|
||||
@@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
|
||||
span.add_event.assert_called_once()
|
||||
span.set_status.assert_called_once()
|
||||
span.end.assert_called_once_with(end_time=20)
|
||||
assert client.span_contexts[2] == "ctx"
|
||||
assert client.span_contexts[2] == ctx
|
||||
|
||||
|
||||
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
client.span_contexts[10] = "existing"
|
||||
existing_context = _make_span_context(span_id=10)
|
||||
client.span_contexts[10] = existing_context
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "child"
|
||||
span.get_span_context.return_value = _make_span_context(span_id=11)
|
||||
|
||||
data = SpanData(
|
||||
trace_id=1,
|
||||
@@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
|
||||
|
||||
client._create_and_export_span(data)
|
||||
trace_api = patch_core_components["trace_api"]
|
||||
trace_api.NonRecordingSpan.assert_called_once_with("existing")
|
||||
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
|
||||
trace_api.set_span_in_context.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
span.get_span_context.return_value = _make_span_context(span_id=2)
|
||||
client.tracer.start_span.side_effect = RuntimeError("boom")
|
||||
|
||||
client._create_and_export_span(
|
||||
@@ -385,7 +407,7 @@ def test_get_project_url() -> None:
|
||||
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
|
||||
|
||||
|
||||
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
|
||||
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span_processor = patch_core_components["span_processor"]
|
||||
tracer_provider = patch_core_components["tracer_provider"]
|
||||
@@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
|
||||
metric_reader.shutdown.assert_called_once()
|
||||
|
||||
|
||||
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
|
||||
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
meter_provider = meter_provider_instances[-1]
|
||||
meter_provider.shutdown.side_effect = RuntimeError("boom")
|
||||
assert client.metric_reader is not None
|
||||
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
|
||||
|
||||
client.shutdown()
|
||||
@@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
|
||||
assert client.metric_reader is None
|
||||
|
||||
|
||||
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
|
||||
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
@@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
|
||||
logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
|
||||
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
span.get_span_context.return_value = _make_span_context(span_id=2)
|
||||
|
||||
data = SpanData.model_construct(
|
||||
trace_id=1,
|
||||
@@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
|
||||
hist_mock = MagicMock(name="hist_llm_duration")
|
||||
client.hist_llm_duration = hist_mock
|
||||
|
||||
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
|
||||
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
|
||||
_, attrs = hist_mock.record.call_args.args
|
||||
assert isinstance(attrs["foo"], str)
|
||||
assert attrs["bar"] == 2
|
||||
@@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
|
||||
hist_mock = MagicMock(name="hist_trace_duration")
|
||||
client.hist_trace_duration = hist_mock
|
||||
|
||||
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
|
||||
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
|
||||
_, attrs = hist_mock.record.call_args.args
|
||||
assert isinstance(attrs["meta"], str)
|
||||
assert attrs["ok"] is True
|
||||
@@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
|
||||
],
|
||||
)
|
||||
def test_record_methods_handle_exceptions(
|
||||
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
|
||||
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
|
||||
) -> None:
|
||||
client = _build_client()
|
||||
hist_mock = MagicMock(name=attr_name)
|
||||
@@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
|
||||
def test_metrics_initializes_grpc_metric_exporter() -> None:
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
|
||||
assert isinstance(exporter, DummyGrpcMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
|
||||
assert metric_reader.exporter.kwargs["insecure"] is False
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
|
||||
assert exporter.kwargs["insecure"] is False
|
||||
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
|
||||
|
||||
|
||||
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
|
||||
assert isinstance(exporter, DummyHttpMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
assert exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
|
||||
|
||||
|
||||
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
|
||||
assert isinstance(exporter, DummyJsonMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
assert "preferred_temporality" in metric_reader.exporter.kwargs
|
||||
assert exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
|
||||
assert "preferred_temporality" in exporter.kwargs
|
||||
|
||||
|
||||
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
|
||||
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
|
||||
_ = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
|
||||
assert "preferred_temporality" not in metric_reader.exporter.kwargs
|
||||
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
|
||||
assert "preferred_temporality" not in exporter.kwargs
|
||||
|
||||
|
||||
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@@ -31,13 +31,13 @@ class TestWeaveConfig:
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig()
|
||||
WeaveConfig.model_validate({})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(api_key="key")
|
||||
WeaveConfig.model_validate({"api_key": "key"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(project="project")
|
||||
WeaveConfig.model_validate({"project": "project"})
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
|
||||
@@ -171,35 +171,13 @@ class TestChatMessageApiPermissions:
|
||||
parent_message_id=None,
|
||||
)
|
||||
|
||||
class MockQuery:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
if getattr(self.model, "__name__", "") == "Conversation":
|
||||
return mock_conversation
|
||||
return None
|
||||
|
||||
def order_by(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, *_):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
if getattr(self.model, "__name__", "") == "Message":
|
||||
return [mock_message]
|
||||
return []
|
||||
|
||||
mock_session = mock.Mock()
|
||||
mock_session.query.side_effect = MockQuery
|
||||
mock_session.scalar.return_value = False
|
||||
mock_session.scalar.return_value = mock_conversation
|
||||
mock_session.scalars.return_value.all.return_value = [mock_message]
|
||||
|
||||
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock())
|
||||
|
||||
class DummyPagination:
|
||||
def __init__(self, data, limit, has_more):
|
||||
|
||||
@@ -24,7 +24,6 @@ def _patch_wraps():
|
||||
patch("controllers.console.wraps.dify_config", dify_settings),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
|
||||
def _execute_result(rows):
|
||||
result = mock.Mock()
|
||||
result.all.return_value = rows
|
||||
return result
|
||||
|
||||
|
||||
class TestFeedbackService:
|
||||
"""Test FeedbackService methods."""
|
||||
|
||||
@@ -81,25 +87,17 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback data in CSV format."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test CSV export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
@@ -120,25 +118,17 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback data in JSON format."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test JSON export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
@@ -157,25 +147,17 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback with various filters."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test with filters
|
||||
result = FeedbackService.export_feedbacks(
|
||||
@@ -193,17 +175,7 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback when no data exists."""
|
||||
|
||||
# Setup mock query result with no data
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result([])
|
||||
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
@@ -251,24 +223,17 @@ class TestFeedbackService:
|
||||
created_at=datetime(2024, 1, 1, 10, 0, 0),
|
||||
)
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
long_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
long_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
@@ -309,24 +274,17 @@ class TestFeedbackService:
|
||||
created_at=datetime(2024, 1, 1, 10, 0, 0),
|
||||
)
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
chinese_feedback,
|
||||
chinese_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
None, # No account for user feedback
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
chinese_feedback,
|
||||
chinese_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
None,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
@@ -339,32 +297,24 @@ class TestFeedbackService:
|
||||
|
||||
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
|
||||
"""Test that rating emojis are properly formatted in export."""
|
||||
|
||||
# Setup mock query result with both like and dislike feedback
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
),
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
),
|
||||
]
|
||||
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
mock_db_session.execute.return_value = _execute_result(
|
||||
[
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
),
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
@@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
|
||||
configure_session_factory(_unit_test_engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
|
||||
def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
|
||||
"""
|
||||
Helper to set up the mock DB execute chain for tenant/account authentication.
|
||||
Helper to stub the tenant-owner execute result for service API app authentication.
|
||||
|
||||
This configures the mock to return (tenant, account) for the
|
||||
db.session.execute(select(...).join().join().where()).one_or_none()
|
||||
query used by validate_app_token decorator.
|
||||
The validate_app_token decorator currently resolves the active tenant owner
|
||||
via db.session.execute(select(Tenant, Account)...).one_or_none().
|
||||
|
||||
Args:
|
||||
mock_db: The mocked db object
|
||||
mock_tenant: Mock tenant object to return
|
||||
mock_account: Mock account object to return
|
||||
mock_owner: Mock owner object to return from the execute result
|
||||
"""
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)
|
||||
|
||||
|
||||
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
|
||||
def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
|
||||
"""
|
||||
Helper to set up the mock DB execute chain for dataset tenant authentication.
|
||||
Helper to stub the tenant-owner execute result for dataset token authentication.
|
||||
|
||||
This configures the mock to return (tenant, tenant_account) for the
|
||||
db.session.execute(select(...).where().where().where().where()).one_or_none()
|
||||
query used by validate_dataset_token decorator.
|
||||
The validate_dataset_token decorator currently resolves the owner mapping via
|
||||
db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
|
||||
then loads the Account separately via db.session.get(...).
|
||||
|
||||
Args:
|
||||
mock_db: The mocked db object
|
||||
mock_tenant: Mock tenant object to return
|
||||
mock_ta: Mock tenant account object to return
|
||||
mock_tenant_account_join: Mock tenant-account join object to return
|
||||
"""
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)
|
||||
|
||||
@@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
@@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
@@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with (
|
||||
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
|
||||
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
|
||||
@@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
|
||||
# Act
|
||||
@@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
@@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
# Act
|
||||
@@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
|
||||
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
|
||||
"""Test that reset password returns success with token for existing accounts."""
|
||||
# Mock the setup check
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Test with existing account
|
||||
mock_get_user.return_value = MagicMock(email="existing@example.com")
|
||||
|
||||
@@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "email_token_123"
|
||||
@@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Registration is allowed by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
@@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Registration is blocked by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = False
|
||||
@@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Prevents spam and abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.side_effect = AccountRegisterError("Account frozen")
|
||||
|
||||
@@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Defaults to en-US when not specified
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
@@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is logged in with token pair
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
@@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is logged in after account creation
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = None
|
||||
mock_create_account.return_value = mock_account
|
||||
@@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
|
||||
- InvalidEmailError is raised when email doesn't match token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
@@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
|
||||
- EmailCodeError is raised for wrong verification code
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
@@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is added as owner of new workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
@@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
|
||||
- WorkspacesLimitExceeded is raised when limit reached
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
@@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
|
||||
- NotAllowedCreateWorkspace is raised when creation disabled
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
|
||||
@@ -110,7 +110,6 @@ class TestLoginApi:
|
||||
- Rate limit is reset after successful login
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
@@ -162,7 +161,6 @@ class TestLoginApi:
|
||||
- Authentication proceeds with invitation token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
|
||||
mock_authenticate.return_value = mock_account
|
||||
@@ -199,7 +197,6 @@ class TestLoginApi:
|
||||
- EmailPasswordLoginLimitError is raised when limit exceeded
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
@@ -228,7 +225,6 @@ class TestLoginApi:
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_frozen.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@@ -268,7 +264,6 @@ class TestLoginApi:
|
||||
- Generic error message prevents user enumeration
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
|
||||
@@ -305,7 +300,6 @@ class TestLoginApi:
|
||||
- Login is prevented even with valid credentials
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountLoginError("Account is banned")
|
||||
@@ -351,7 +345,6 @@ class TestLoginApi:
|
||||
- User cannot login without an assigned workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
@@ -383,7 +376,6 @@ class TestLoginApi:
|
||||
- Security check prevents invitation token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
|
||||
|
||||
@@ -425,7 +417,6 @@ class TestLoginApi:
|
||||
mock_token_pair,
|
||||
):
|
||||
"""Test that login retries with lowercase email when uppercase lookup fails."""
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
|
||||
@@ -459,7 +450,6 @@ class TestLoginApi:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_account.side_effect = Unauthorized("Account is banned.")
|
||||
|
||||
@@ -513,7 +503,6 @@ class TestLogoutApi:
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
|
||||
# Act
|
||||
@@ -539,7 +528,6 @@ class TestLogoutApi:
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
# Create a mock anonymous user that will pass isinstance check
|
||||
anonymous_user = MagicMock()
|
||||
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})
|
||||
|
||||
@@ -46,7 +46,6 @@ class TestPartnerTenants:
|
||||
patch("libs.login.dify_config.LOGIN_DISABLED", False),
|
||||
patch("libs.login.check_csrf_token") as mock_csrf,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_csrf.return_value = None
|
||||
yield {"db": mock_db, "csrf": mock_csrf}
|
||||
|
||||
|
||||
@@ -24,10 +24,6 @@ def app():
|
||||
return app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
|
||||
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
|
||||
account = Account(name=account_id, email=email)
|
||||
@@ -64,7 +60,6 @@ class TestChangeEmailSend:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@@ -117,7 +112,6 @@ class TestChangeEmailSend:
|
||||
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@@ -163,7 +157,6 @@ class TestChangeEmailValidity:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@@ -223,7 +216,6 @@ class TestChangeEmailValidity:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@@ -280,7 +272,6 @@ class TestChangeEmailValidity:
|
||||
"""A token whose phase marker is a string but not a known transition must be rejected."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@@ -330,7 +321,6 @@ class TestChangeEmailValidity:
|
||||
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@@ -378,7 +368,6 @@ class TestChangeEmailReset:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@@ -434,7 +423,6 @@ class TestChangeEmailReset:
|
||||
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@@ -488,7 +476,6 @@ class TestChangeEmailReset:
|
||||
"""A verified token for address A must not be replayed to change to address B."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
|
||||
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
with app.test_request_context(
|
||||
"/account/delete/feedback",
|
||||
method="POST",
|
||||
@@ -578,7 +564,6 @@ class TestCheckEmailUnique:
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
@@ -16,10 +16,6 @@ def app():
|
||||
return flask_app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_feature_flags():
|
||||
placeholder_quota = SimpleNamespace(limit=0, size=0)
|
||||
workspace_members = SimpleNamespace(is_available=lambda count: True)
|
||||
@@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
|
||||
mock_get_features,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_get_features.return_value = _build_feature_flags()
|
||||
mock_invite_member.return_value = "token-abc"
|
||||
|
||||
|
||||
@@ -310,7 +310,6 @@ class TestSystemSetup:
|
||||
def test_should_allow_when_setup_complete(self, mock_db):
|
||||
"""Test that requests are allowed when setup is complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
|
||||
|
||||
@setup_required
|
||||
def admin_view():
|
||||
|
||||
@@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None
|
||||
|
||||
@contextmanager
|
||||
def _mock_db():
|
||||
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
|
||||
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
|
||||
with patch("extensions.ext_database.db.session", mock_session):
|
||||
yield
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, AppMode
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
|
||||
|
||||
|
||||
class TestAppParameterApi:
|
||||
@@ -74,7 +74,7 @@ class TestAppParameterApi:
|
||||
# Mock tenant owner info for login
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -120,7 +120,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -161,7 +161,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -200,7 +200,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -263,7 +263,7 @@ class TestAppMetaApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -331,7 +331,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -388,7 +388,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -434,7 +434,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -486,7 +486,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
|
||||
@@ -15,7 +15,10 @@ from flask import Flask
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, AppMode, EndUser
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
from tests.unit_tests.conftest import (
|
||||
setup_mock_dataset_owner_execute_result,
|
||||
setup_mock_tenant_owner_execute_result,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -123,9 +126,7 @@ class AuthenticationMocker:
|
||||
mock_db.session.get.side_effect = [mock_app, mock_tenant]
|
||||
|
||||
if mock_account:
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
@staticmethod
|
||||
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
|
||||
@@ -133,8 +134,7 @@ class AuthenticationMocker:
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
|
||||
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
|
||||
mock_db.session.get.return_value = mock_account
|
||||
|
||||
|
||||
|
||||
@@ -701,8 +701,8 @@ class TestDocumentApiDelete:
|
||||
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
|
||||
internally calls ``validate_and_get_api_token``. To bypass the decorator
|
||||
we call the original function via ``__wrapped__`` (preserved by
|
||||
``functools.wraps``). ``delete`` queries the dataset via
|
||||
``db.session.query(Dataset)`` directly, so we patch ``db`` at the
|
||||
``functools.wraps``). ``delete`` loads the dataset via
|
||||
``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
|
||||
controller module.
|
||||
"""
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
|
||||
from models.account import TenantStatus
|
||||
from models.model import ApiToken
|
||||
from tests.unit_tests.conftest import (
|
||||
setup_mock_dataset_tenant_query,
|
||||
setup_mock_tenant_account_query,
|
||||
setup_mock_dataset_owner_execute_result,
|
||||
setup_mock_tenant_owner_execute_result,
|
||||
)
|
||||
|
||||
|
||||
@@ -141,14 +141,11 @@ class TestValidateAppToken:
|
||||
mock_account = Mock()
|
||||
mock_account.id = str(uuid.uuid4())
|
||||
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
# Use side_effect to return app first, then tenant via session.get()
|
||||
mock_db.session.get.side_effect = [mock_app, mock_tenant]
|
||||
|
||||
# Mock the tenant owner query (execute(select(...)).one_or_none())
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
|
||||
# Mock the tenant owner execute result (execute(select(...)).one_or_none())
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(app_model):
|
||||
@@ -471,7 +468,7 @@ class TestValidateDatasetToken:
|
||||
mock_account.current_tenant = mock_tenant
|
||||
|
||||
# Mock the tenant account join query (execute(select(...)).one_or_none())
|
||||
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
|
||||
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
|
||||
|
||||
# Mock the account lookup via session.get()
|
||||
mock_db.session.get.return_value = mock_account
|
||||
|
||||
@@ -22,18 +22,16 @@ class FakeSession:
|
||||
|
||||
def __init__(self, mapping: dict[str, Any] | None = None):
|
||||
self._mapping: dict[str, Any] = mapping or {}
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model: type) -> FakeSession:
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
def get(self, model: type, _ident: object) -> Any:
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
def scalar(self, stmt: Any) -> Any:
|
||||
try:
|
||||
model = stmt.column_descriptions[0]["entity"]
|
||||
except (AttributeError, IndexError, KeyError, TypeError):
|
||||
return None
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
|
||||
@@ -36,18 +36,6 @@ class _FakeSession:
|
||||
|
||||
def __init__(self, mapping: dict[str, Any]):
|
||||
self._mapping = mapping
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model):
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
def get(self, model, ident):
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
@@ -34,7 +34,6 @@ def _patch_wraps():
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
patch("controllers.web.login.dify_config", web_dify),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
@@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock ConversationVariable.from_variable to return mock objects
|
||||
@@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
|
||||
@@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker):
|
||||
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.get.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
|
||||
@@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker):
|
||||
app_generate_entity.single_iteration_run = None
|
||||
app_generate_entity.single_loop_run = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.get.side_effect = [None, None]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
@@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query_pipeline
|
||||
session.get.side_effect = [None, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
|
||||
@@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime:
|
||||
"last_edited_time": "2024-11-27T18:00:00.000Z",
|
||||
}
|
||||
mock_request.return_value = mock_response
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
extractor_page.update_last_edited_time(mock_document_model)
|
||||
@@ -863,9 +860,6 @@ class TestNotionExtractorIntegration:
|
||||
}
|
||||
|
||||
mock_request.side_effect = [last_edited_response, block_response]
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
documents = extractor.extract()
|
||||
@@ -919,10 +913,6 @@ class TestNotionExtractorIntegration:
|
||||
}
|
||||
mock_post.return_value = database_response
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
documents = extractor.extract()
|
||||
|
||||
|
||||
@@ -40,11 +40,11 @@ class TestObfuscatedToken:
|
||||
class TestEncryptToken:
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
def test_successful_encryption(self, mock_encrypt, mock_get):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
@@ -53,9 +53,9 @@ class TestEncryptToken:
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
def test_tenant_not_found(self, mock_get):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value = None
|
||||
mock_get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
@@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration:
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
@@ -148,12 +148,12 @@ class TestSecurity:
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_get):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
@@ -183,10 +183,10 @@ class TestSecurity:
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_get):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
@@ -207,11 +207,11 @@ class TestEdgeCases:
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
@@ -221,11 +221,11 @@ class TestEdgeCases:
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
@@ -244,11 +244,11 @@ class TestEdgeCases:
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
|
||||
@@ -491,7 +491,7 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
mock_session.return_value.scalar.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}
|
||||
|
||||
@@ -517,7 +517,7 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
mock_session.return_value.scalar.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
# Cause exception in node_type logic
|
||||
workflow.graph_dict = {"graph": {"nodes": []}}
|
||||
@@ -544,7 +544,7 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
mock_session.return_value.scalar.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
||||
|
||||
@@ -632,7 +632,7 @@ class TestLLMGenerator:
|
||||
instance.invoke_llm.return_value = mock_response
|
||||
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
mock_session.return_value.scalar.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}
|
||||
|
||||
|
||||
@@ -29,15 +29,6 @@ class _Field:
|
||||
return ("in", self._name, tuple(values))
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self):
|
||||
self.where_calls: list[tuple] = []
|
||||
|
||||
def where(self, *conditions):
|
||||
self.where_calls.append(conditions)
|
||||
return self
|
||||
|
||||
|
||||
class _FakeExecuteResult:
|
||||
def __init__(self, segments: list[SimpleNamespace]):
|
||||
self._segments = segments
|
||||
|
||||
@@ -109,17 +109,6 @@ class _FakeExecuteResult:
|
||||
return _FakeExecuteScalarResult(self._data)
|
||||
|
||||
|
||||
class _FakeSummaryQuery:
|
||||
def __init__(self, summaries: list) -> None:
|
||||
self._summaries = summaries
|
||||
|
||||
def filter(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def all(self) -> list:
|
||||
return self._summaries
|
||||
|
||||
|
||||
class _FakeScalarsResult:
|
||||
def __init__(self, data: list) -> None:
|
||||
self._data = data
|
||||
|
||||
@@ -372,19 +372,11 @@ def test_vector_delegation_methods(vector_factory_module):
|
||||
|
||||
|
||||
def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
|
||||
class _Field:
|
||||
def __eq__(self, value):
|
||||
return value
|
||||
|
||||
upload_query = MagicMock()
|
||||
upload_query.where.return_value = upload_query
|
||||
|
||||
vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
|
||||
vector._embeddings = MagicMock()
|
||||
vector._vector_processor = MagicMock()
|
||||
|
||||
mock_session = SimpleNamespace(get=lambda _model, _id: None)
|
||||
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
|
||||
monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))
|
||||
|
||||
assert vector.search_by_file("file-1") == []
|
||||
|
||||
@@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk:
|
||||
|
||||
mock_dependencies["redis"].get.return_value = None
|
||||
|
||||
# Mock database query for segment updates
|
||||
mock_query = MagicMock()
|
||||
mock_dependencies["db"].session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.update.return_value = None
|
||||
# Mock database update for segment status
|
||||
mock_dependencies["db"].session.execute.return_value = None
|
||||
|
||||
# Create a proper context manager mock
|
||||
mock_context = MagicMock()
|
||||
|
||||
@@ -2417,12 +2417,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
|
||||
mock_document.data_source_type = "upload_file"
|
||||
mock_document.doc_metadata = {}
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_dataset_from_db
|
||||
]
|
||||
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
|
||||
[mock_dataset_from_db, mock_document]
|
||||
)
|
||||
mock_datasets = MagicMock()
|
||||
mock_datasets.all.return_value = [mock_dataset_from_db]
|
||||
mock_documents = MagicMock()
|
||||
mock_documents.all.return_value = [mock_document]
|
||||
mock_session.scalars.side_effect = [mock_datasets, mock_documents]
|
||||
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
@@ -451,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
|
||||
mock_document.data_source_type = "upload_file"
|
||||
mock_document.doc_metadata = {}
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_dataset_from_db
|
||||
]
|
||||
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
|
||||
[mock_dataset_from_db, mock_document]
|
||||
)
|
||||
mock_datasets = MagicMock()
|
||||
mock_datasets.all.return_value = [mock_dataset_from_db]
|
||||
mock_documents = MagicMock()
|
||||
mock_documents.all.return_value = [mock_document]
|
||||
mock_session.scalars.side_effect = [mock_datasets, mock_documents]
|
||||
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,925 +0,0 @@
|
||||
"""
|
||||
Extensive unit tests for ``ExternalDatasetService``.
|
||||
|
||||
This module focuses on the *external dataset service* surface area, which is responsible
|
||||
for integrating with **external knowledge APIs** and wiring them into Dify datasets.
|
||||
|
||||
The goal of this test suite is twofold:
|
||||
|
||||
- Provide **high‑confidence regression coverage** for all public helpers on
|
||||
``ExternalDatasetService``.
|
||||
- Serve as **executable documentation** for how external API integration is expected
|
||||
to behave in different scenarios (happy paths, validation failures, and error codes).
|
||||
|
||||
The file intentionally contains **rich comments and generous spacing** in order to make
|
||||
each scenario easy to scan during reviews.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
|
||||
from services.entities.external_knowledge_entities.external_knowledge_entities import (
|
||||
Authorization,
|
||||
AuthorizationConfig,
|
||||
ExternalKnowledgeApiSetting,
|
||||
)
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
|
||||
class ExternalDatasetTestDataFactory:
|
||||
"""
|
||||
Factory helpers for building *lightweight* mocks for external knowledge tests.
|
||||
|
||||
These helpers are intentionally small and explicit:
|
||||
|
||||
- They avoid pulling in unnecessary fixtures.
|
||||
- They reflect the minimal contract that the service under test cares about.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_external_api(
|
||||
api_id: str = "api-123",
|
||||
tenant_id: str = "tenant-1",
|
||||
name: str = "Test API",
|
||||
description: str = "Description",
|
||||
settings: dict[str, Any] | None = None,
|
||||
) -> ExternalKnowledgeApis:
|
||||
"""
|
||||
Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.
|
||||
|
||||
Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
|
||||
exercise ``settings_dict`` and other convenience properties if needed.
|
||||
"""
|
||||
|
||||
instance = ExternalKnowledgeApis(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
settings=None if settings is None else cast(str, pytest.approx), # type: ignore[assignment]
|
||||
)
|
||||
|
||||
# Overwrite generated id for determinism in assertions.
|
||||
instance.id = api_id
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
dataset_id: str = "ds-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
name: str = "External Dataset",
|
||||
provider: str = "external",
|
||||
) -> Dataset:
|
||||
"""
|
||||
Build a small ``Dataset`` instance representing an external dataset.
|
||||
"""
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description="",
|
||||
provider=provider,
|
||||
created_by="user-1",
|
||||
)
|
||||
dataset.id = dataset_id
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_external_binding(
|
||||
tenant_id: str = "tenant-1",
|
||||
dataset_id: str = "ds-1",
|
||||
api_id: str = "api-1",
|
||||
external_knowledge_id: str = "knowledge-1",
|
||||
) -> ExternalKnowledgeBindings:
|
||||
"""
|
||||
Small helper for a binding between dataset and external knowledge API.
|
||||
"""
|
||||
|
||||
binding = ExternalKnowledgeBindings(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
external_knowledge_api_id=api_id,
|
||||
external_knowledge_id=external_knowledge_id,
|
||||
created_by="user-1",
|
||||
)
|
||||
return binding
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_external_knowledge_apis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceGetExternalKnowledgeApis:
|
||||
"""
|
||||
Tests for ``ExternalDatasetService.get_external_knowledge_apis``.
|
||||
|
||||
These tests focus on:
|
||||
|
||||
- Basic pagination wiring via ``db.paginate``.
|
||||
- Optional search keyword behaviour.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_paginate(self):
|
||||
"""
|
||||
Patch ``db.paginate`` so we do not touch the real database layer.
|
||||
"""
|
||||
|
||||
with (
|
||||
patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate,
|
||||
patch("services.external_knowledge_service.select", autospec=True),
|
||||
):
|
||||
yield mock_paginate
|
||||
|
||||
def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
|
||||
"""
|
||||
It should return ``items`` and ``total`` coming from the paginate object.
|
||||
"""
|
||||
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
|
||||
mock_pagination = SimpleNamespace(items=mock_items, total=42)
|
||||
mock_db_paginate.return_value = mock_pagination
|
||||
|
||||
# Act
|
||||
items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert items is mock_items
|
||||
assert total == 42
|
||||
|
||||
mock_db_paginate.assert_called_once()
|
||||
call_kwargs = mock_db_paginate.call_args.kwargs
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == per_page
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
assert call_kwargs["error_out"] is False
|
||||
|
||||
def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
|
||||
"""
|
||||
When a search keyword is provided, the query should be adjusted
|
||||
(we simply assert that paginate is still called and does not explode).
|
||||
"""
|
||||
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
page = 2
|
||||
per_page = 10
|
||||
search = "foo"
|
||||
|
||||
mock_pagination = SimpleNamespace(items=[], total=0)
|
||||
mock_db_paginate.return_value = mock_pagination
|
||||
|
||||
# Act
|
||||
items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search)
|
||||
|
||||
# Assert
|
||||
assert items == []
|
||||
assert total == 0
|
||||
mock_db_paginate.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_api_list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceValidateApiList:
|
||||
"""
|
||||
Lightweight validation tests for ``validate_api_list``.
|
||||
"""
|
||||
|
||||
def test_validate_api_list_success(self):
|
||||
"""
|
||||
A minimal valid configuration (endpoint + api_key) should pass.
|
||||
"""
|
||||
|
||||
config = {"endpoint": "https://example.com", "api_key": "secret"}
|
||||
|
||||
# Act & Assert – no exception expected
|
||||
ExternalDatasetService.validate_api_list(config)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("config", "expected_message"),
|
||||
[
|
||||
({}, "api list is empty"),
|
||||
({"api_key": "k"}, "endpoint is required"),
|
||||
({"endpoint": "https://example.com"}, "api_key is required"),
|
||||
],
|
||||
)
|
||||
def test_validate_api_list_failures(self, config: dict[str, Any], expected_message: str):
|
||||
"""
|
||||
Invalid configs should raise ``ValueError`` with a clear message.
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError, match=expected_message):
|
||||
ExternalDatasetService.validate_api_list(config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create_external_knowledge_api & get/update/delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
"""
|
||||
CRUD tests for external knowledge API templates.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""
|
||||
Patch ``db.session`` for all CRUD tests in this class.
|
||||
"""
|
||||
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
``create_external_knowledge_api`` should persist a new record
|
||||
when settings are present and valid.
|
||||
"""
|
||||
|
||||
tenant_id = "tenant-1"
|
||||
user_id = "user-1"
|
||||
args = {
|
||||
"name": "API",
|
||||
"description": "desc",
|
||||
"settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
|
||||
}
|
||||
|
||||
# We do not want to actually call the remote endpoint here, so we patch the validator.
|
||||
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
|
||||
result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
|
||||
|
||||
assert isinstance(result, ExternalKnowledgeApis)
|
||||
mock_check.assert_called_once_with(args["settings"])
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Missing ``settings`` should result in a ``ValueError``.
|
||||
"""
|
||||
|
||||
tenant_id = "tenant-1"
|
||||
user_id = "user-1"
|
||||
args = {"name": "API", "description": "desc"}
|
||||
|
||||
with pytest.raises(ValueError, match="settings is required"):
|
||||
ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
|
||||
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
``get_external_knowledge_api`` should return the first matching record.
|
||||
"""
|
||||
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
mock_db_session.scalar.return_value = api
|
||||
|
||||
result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
|
||||
assert result is api
|
||||
|
||||
def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
When the record is absent, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")
|
||||
|
||||
def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Updating an API should keep the existing API key when the special hidden
|
||||
value placeholder is sent from the UI.
|
||||
"""
|
||||
|
||||
tenant_id = "tenant-1"
|
||||
user_id = "user-1"
|
||||
api_id = "api-1"
|
||||
|
||||
existing_api = Mock(spec=ExternalKnowledgeApis)
|
||||
existing_api.settings_dict = {"api_key": "stored-key"}
|
||||
existing_api.settings = '{"api_key":"stored-key"}'
|
||||
mock_db_session.scalar.return_value = existing_api
|
||||
|
||||
args = {
|
||||
"name": "New Name",
|
||||
"description": "New Desc",
|
||||
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
|
||||
}
|
||||
|
||||
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
|
||||
|
||||
assert result is existing_api
|
||||
# The placeholder should be replaced with stored key.
|
||||
assert args["settings"]["api_key"] == "stored-key"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Updating a non‑existent API template should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
external_knowledge_api_id="missing-id",
|
||||
args={"name": "n", "description": "d", "settings": {}},
|
||||
)
|
||||
|
||||
def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
``delete_external_knowledge_api`` should delete and commit when found.
|
||||
"""
|
||||
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
mock_db_session.scalar.return_value = api
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
|
||||
|
||||
mock_db_session.delete.assert_called_once_with(api)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Deletion of a missing template should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# external_knowledge_api_use_check & binding lookups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceUsageAndBindings:
|
||||
"""
|
||||
Tests for usage checks and dataset binding retrieval.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = 3
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||
|
||||
assert in_use is True
|
||||
assert count == 3
|
||||
assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])
|
||||
|
||||
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Zero bindings should return ``(False, 0)``.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = 0
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||
|
||||
assert in_use is False
|
||||
assert count == 0
|
||||
|
||||
def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Binding lookup should return the first record when present.
|
||||
"""
|
||||
|
||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
||||
mock_db_session.scalar.return_value = binding
|
||||
|
||||
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
|
||||
assert result is binding
|
||||
|
||||
def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Missing binding should result in a ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# document_create_args_validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
"""
|
||||
Tests for ``document_create_args_validate``.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
All required custom parameters present – validation should pass.
|
||||
"""
|
||||
|
||||
external_api = Mock(spec=ExternalKnowledgeApis)
|
||||
external_api.settings = json_settings = (
|
||||
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
|
||||
)
|
||||
# Raw string; the service itself calls json.loads on it
|
||||
mock_db_session.scalar.return_value = external_api
|
||||
|
||||
process_parameter = {"foo": "value", "bar": "optional"}
|
||||
|
||||
# Act & Assert – no exception
|
||||
ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
|
||||
|
||||
assert json_settings in external_api.settings # simple sanity check on our test data
|
||||
|
||||
def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
When the referenced API template is missing, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
|
||||
|
||||
def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Required document process parameters must be supplied.
|
||||
"""
|
||||
|
||||
external_api = Mock(spec=ExternalKnowledgeApis)
|
||||
external_api.settings = (
|
||||
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
|
||||
)
|
||||
mock_db_session.scalar.return_value = external_api
|
||||
|
||||
process_parameter = {"bar": "present"} # missing "foo"
|
||||
|
||||
with pytest.raises(ValueError, match="foo is required"):
|
||||
ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_external_api
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceProcessExternalApi:
|
||||
"""
|
||||
Tests focused on the HTTP request assembly and method mapping behaviour.
|
||||
"""
|
||||
|
||||
def test_process_external_api_valid_method_post(self):
|
||||
"""
|
||||
For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function.
|
||||
"""
|
||||
|
||||
settings = ExternalKnowledgeApiSetting(
|
||||
url="https://example.com/path",
|
||||
request_method="POST",
|
||||
headers={"X-Test": "1"},
|
||||
params={"foo": "bar"},
|
||||
)
|
||||
|
||||
fake_response = httpx.Response(200)
|
||||
|
||||
with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
|
||||
mock_post.return_value = fake_response
|
||||
|
||||
result = ExternalDatasetService.process_external_api(settings, files=None)
|
||||
|
||||
assert result is fake_response
|
||||
mock_post.assert_called_once()
|
||||
kwargs = mock_post.call_args.kwargs
|
||||
assert kwargs["url"] == settings.url
|
||||
assert kwargs["headers"] == settings.headers
|
||||
assert kwargs["follow_redirects"] is True
|
||||
assert "data" in kwargs
|
||||
|
||||
def test_process_external_api_invalid_method_raises(self):
|
||||
"""
|
||||
An unsupported HTTP verb should raise ``InvalidHttpMethodError``.
|
||||
"""
|
||||
|
||||
settings = ExternalKnowledgeApiSetting(
|
||||
url="https://example.com",
|
||||
request_method="INVALID",
|
||||
headers=None,
|
||||
params={},
|
||||
)
|
||||
|
||||
from graphon.nodes.http_request.exc import InvalidHttpMethodError
|
||||
|
||||
with pytest.raises(InvalidHttpMethodError):
|
||||
ExternalDatasetService.process_external_api(settings, files=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# assembling_headers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceAssemblingHeaders:
|
||||
"""
|
||||
Tests for header assembly based on different authentication flavours.
|
||||
"""
|
||||
|
||||
def test_assembling_headers_bearer_token(self):
|
||||
"""
|
||||
For bearer auth we expect ``Authorization: Bearer <key>`` by default.
|
||||
"""
|
||||
|
||||
auth = Authorization(
|
||||
type="api-key",
|
||||
config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
|
||||
)
|
||||
|
||||
headers = ExternalDatasetService.assembling_headers(auth)
|
||||
|
||||
assert headers["Authorization"] == "Bearer secret"
|
||||
|
||||
def test_assembling_headers_basic_token_with_custom_header(self):
|
||||
"""
|
||||
For basic auth we honour the configured header name.
|
||||
"""
|
||||
|
||||
auth = Authorization(
|
||||
type="api-key",
|
||||
config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
|
||||
)
|
||||
|
||||
headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"})
|
||||
|
||||
assert headers["Existing"] == "1"
|
||||
assert headers["X-Auth"] == "Basic abc123"
|
||||
|
||||
def test_assembling_headers_custom_type(self):
|
||||
"""
|
||||
Custom auth type should inject the raw API key.
|
||||
"""
|
||||
|
||||
auth = Authorization(
|
||||
type="api-key",
|
||||
config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
|
||||
)
|
||||
|
||||
headers = ExternalDatasetService.assembling_headers(auth, headers=None)
|
||||
|
||||
assert headers["X-API-KEY"] == "raw-key"
|
||||
|
||||
def test_assembling_headers_missing_config_raises(self):
|
||||
"""
|
||||
Missing config object should be rejected.
|
||||
"""
|
||||
|
||||
auth = Authorization(type="api-key", config=None)
|
||||
|
||||
with pytest.raises(ValueError, match="authorization config is required"):
|
||||
ExternalDatasetService.assembling_headers(auth)
|
||||
|
||||
def test_assembling_headers_missing_api_key_raises(self):
|
||||
"""
|
||||
``api_key`` is required when type is ``api-key``.
|
||||
"""
|
||||
|
||||
auth = Authorization(
|
||||
type="api-key",
|
||||
config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="api_key is required"):
|
||||
ExternalDatasetService.assembling_headers(auth)
|
||||
|
||||
def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
|
||||
"""
|
||||
For ``no-auth`` we should not modify the headers mapping.
|
||||
"""
|
||||
|
||||
auth = Authorization(type="no-auth", config=None)
|
||||
|
||||
base_headers = {"X": "1"}
|
||||
result = ExternalDatasetService.assembling_headers(auth, headers=base_headers)
|
||||
|
||||
# A copy is returned, original is not mutated.
|
||||
assert result == base_headers
|
||||
assert result is not base_headers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_external_knowledge_api_settings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
|
||||
"""
|
||||
Simple shape test for ``get_external_knowledge_api_settings``.
|
||||
"""
|
||||
|
||||
def test_get_external_knowledge_api_settings(self):
|
||||
settings_dict: dict[str, Any] = {
|
||||
"url": "https://example.com/retrieval",
|
||||
"request_method": "post",
|
||||
"headers": {"Content-Type": "application/json"},
|
||||
"params": {"foo": "bar"},
|
||||
}
|
||||
|
||||
result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict)
|
||||
|
||||
assert isinstance(result, ExternalKnowledgeApiSetting)
|
||||
assert result.url == settings_dict["url"]
|
||||
assert result.request_method == settings_dict["request_method"]
|
||||
assert result.headers == settings_dict["headers"]
|
||||
assert result.params == settings_dict["params"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create_external_dataset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceCreateExternalDataset:
|
||||
"""
|
||||
Tests around creating the external dataset and its binding row.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_create_external_dataset_success(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
A brand new dataset name with valid external knowledge references
|
||||
should create both the dataset and its binding.
|
||||
"""
|
||||
|
||||
tenant_id = "tenant-1"
|
||||
user_id = "user-1"
|
||||
|
||||
args = {
|
||||
"name": "My Dataset",
|
||||
"description": "desc",
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_id": "knowledge-1",
|
||||
"external_retrieval_model": {"top_k": 3},
|
||||
}
|
||||
|
||||
# No existing dataset with same name.
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None, # duplicate‑name check
|
||||
Mock(spec=ExternalKnowledgeApis), # external knowledge api
|
||||
]
|
||||
|
||||
dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
|
||||
|
||||
assert isinstance(dataset, Dataset)
|
||||
assert dataset.provider == "external"
|
||||
assert dataset.retrieval_model == args["external_retrieval_model"]
|
||||
|
||||
assert mock_db_session.add.call_count >= 2 # dataset + binding
|
||||
mock_db_session.flush.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
When a dataset with the same name already exists,
|
||||
``DatasetNameDuplicateError`` is raised.
|
||||
"""
|
||||
|
||||
existing_dataset = Mock(spec=Dataset)
|
||||
mock_db_session.scalar.return_value = existing_dataset
|
||||
|
||||
args = {
|
||||
"name": "Existing",
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_id": "knowledge-1",
|
||||
}
|
||||
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
|
||||
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
# First call: duplicate name check – not found.
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None,
|
||||
None, # external knowledge api lookup
|
||||
]
|
||||
|
||||
args = {
|
||||
"name": "Dataset",
|
||||
"external_knowledge_api_id": "missing",
|
||||
"external_knowledge_id": "knowledge-1",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
|
||||
|
||||
def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
|
||||
"""
|
||||
|
||||
# duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None,
|
||||
Mock(spec=ExternalKnowledgeApis),
|
||||
None,
|
||||
Mock(spec=ExternalKnowledgeApis),
|
||||
]
|
||||
|
||||
args_missing_knowledge_id = {
|
||||
"name": "Dataset",
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_id": None,
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="external_knowledge_id is required"):
|
||||
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)
|
||||
|
||||
args_missing_api_id = {
|
||||
"name": "Dataset",
|
||||
"external_knowledge_api_id": None,
|
||||
"external_knowledge_id": "k-1",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
|
||||
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fetch_external_knowledge_retrieval
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
"""
|
||||
Tests for ``fetch_external_knowledge_retrieval`` which orchestrates
|
||||
external retrieval requests and normalises the response payload.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
With a valid binding and API template, records from the external
|
||||
service should be returned when the HTTP response is 200.
|
||||
"""
|
||||
|
||||
tenant_id = "tenant-1"
|
||||
dataset_id = "ds-1"
|
||||
query = "test query"
|
||||
external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}
|
||||
|
||||
binding = ExternalDatasetTestDataFactory.create_external_binding(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
api_id="api-1",
|
||||
external_knowledge_id="knowledge-1",
|
||||
)
|
||||
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
|
||||
|
||||
# First query: binding; second query: api.
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
api,
|
||||
]
|
||||
|
||||
fake_records = [{"content": "doc", "score": 0.9}]
|
||||
fake_response = Mock(spec=httpx.Response)
|
||||
fake_response.status_code = 200
|
||||
fake_response.json.return_value = {"records": fake_records}
|
||||
|
||||
metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
|
||||
|
||||
with patch.object(
|
||||
ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True
|
||||
) as mock_process:
|
||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=external_retrieval_parameters,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
|
||||
assert result == fake_records
|
||||
|
||||
mock_process.assert_called_once()
|
||||
setting_arg = mock_process.call_args.args[0]
|
||||
assert isinstance(setting_arg, ExternalKnowledgeApiSetting)
|
||||
assert setting_arg.url.endswith("/retrieval")
|
||||
|
||||
def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Missing binding should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="missing",
|
||||
query="q",
|
||||
external_retrieval_parameters={},
|
||||
metadata_condition=None,
|
||||
)
|
||||
|
||||
def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
When the API template is missing or has no settings, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
binding = ExternalDatasetTestDataFactory.create_external_binding()
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
None,
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="external api template not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="ds-1",
|
||||
query="q",
|
||||
external_retrieval_parameters={},
|
||||
metadata_condition=None,
|
||||
)
|
||||
|
||||
def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
Non‑200 responses should be treated as an empty result set.
|
||||
"""
|
||||
|
||||
binding = ExternalDatasetTestDataFactory.create_external_binding()
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
|
||||
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
api,
|
||||
]
|
||||
|
||||
fake_response = Mock(spec=httpx.Response)
|
||||
fake_response.status_code = 500
|
||||
fake_response.json.return_value = {}
|
||||
|
||||
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True):
|
||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="ds-1",
|
||||
query="q",
|
||||
external_retrieval_parameters={},
|
||||
metadata_condition=None,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
@@ -374,24 +374,14 @@ def test_publish_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
mock_db = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db", mock_db)
|
||||
mock_dataset_service_class = mocker.patch("services.dataset_service.DatasetService")
|
||||
mock_dataset_service = mock_dataset_service_class.return_value
|
||||
|
||||
# 6. Mock session and its scalar/query methods
|
||||
# 6. Mock session and dataset lookup
|
||||
mock_session = mocker.Mock()
|
||||
mock_session.scalar.return_value = draft_wf
|
||||
|
||||
# Mock dataset update query (needed even if service is mocked, as rag_pipeline fetches it first)
|
||||
dataset = mocker.Mock()
|
||||
dataset.retrieval_model_dict = {}
|
||||
dataset_query = mocker.Mock()
|
||||
dataset_query.where.return_value.first.return_value = dataset
|
||||
|
||||
# Mock node execution copy
|
||||
node_exec_query = mocker.Mock()
|
||||
node_exec_query.where.return_value.all.return_value = []
|
||||
|
||||
# Mocked session query side effects
|
||||
mock_session.query.side_effect = [node_exec_query, dataset_query]
|
||||
pipeline.retrieve_dataset.return_value = dataset
|
||||
|
||||
# 7. Run test
|
||||
result = rag_pipeline_service.publish_workflow(session=mock_session, pipeline=pipeline, account=account)
|
||||
@@ -1524,7 +1514,6 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
|
||||
)
|
||||
|
||||
document = SimpleNamespace(indexing_status="waiting", error=None)
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
|
||||
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
|
||||
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
@@ -1595,7 +1584,6 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
|
||||
|
||||
|
||||
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
@@ -1604,7 +1592,6 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
|
||||
|
||||
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
@@ -1644,7 +1631,6 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
|
||||
|
||||
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
|
||||
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
@@ -1871,7 +1857,6 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
@@ -1910,7 +1895,6 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
|
||||
|
||||
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
|
||||
exec_log = SimpleNamespace(pipeline_id="p1")
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
|
||||
|
||||
@@ -1923,7 +1907,6 @@ def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_
|
||||
def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
|
||||
exec_log = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1")
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
|
||||
@@ -1940,7 +1923,6 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
|
||||
)
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
|
||||
@@ -2103,7 +2085,6 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
|
||||
graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]},
|
||||
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
|
||||
)
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
@@ -2143,7 +2124,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
|
||||
{"variable": "v3", "belong_to_node_id": "shared"},
|
||||
],
|
||||
)
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
@@ -2161,7 +2141,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
|
||||
def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1")
|
||||
query = mocker.Mock()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
|
||||
result = rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,59 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class ServiceDbTestHelper:
|
||||
"""
|
||||
Helper class for service database query tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def setup_db_query_filter_by_mock(mock_db, query_results):
|
||||
"""
|
||||
Smart database query mock that responds based on model type and query parameters.
|
||||
|
||||
Args:
|
||||
mock_db: Mock database session
|
||||
query_results: Dict mapping (model_name, filter_key, filter_value) to return value
|
||||
Example: {('Account', 'email', 'test@example.com'): mock_account}
|
||||
"""
|
||||
|
||||
def query_side_effect(model):
|
||||
mock_query = MagicMock()
|
||||
|
||||
def filter_by_side_effect(**kwargs):
|
||||
mock_filter_result = MagicMock()
|
||||
|
||||
def first_side_effect():
|
||||
# Find matching result based on model and filter parameters
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_filter_result.first.side_effect = first_side_effect
|
||||
|
||||
# Handle order_by calls for complex queries
|
||||
def order_by_side_effect(*args, **kwargs):
|
||||
mock_order_result = MagicMock()
|
||||
|
||||
def order_first_side_effect():
|
||||
# Look for order_by results in the same query_results dict
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if (
|
||||
model.__name__ == model_name
|
||||
and filter_key == "order_by"
|
||||
and filter_value == "first_available"
|
||||
):
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_order_result.first.side_effect = order_first_side_effect
|
||||
return mock_order_result
|
||||
|
||||
mock_filter_result.order_by.side_effect = order_by_side_effect
|
||||
return mock_filter_result
|
||||
|
||||
mock_query.filter_by.side_effect = filter_by_side_effect
|
||||
return mock_query
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
@@ -14,7 +14,6 @@ from services.errors.account import (
|
||||
AccountRegisterError,
|
||||
CurrentPasswordIncorrectError,
|
||||
)
|
||||
from tests.unit_tests.services.services_test_help import ServiceDbTestHelper
|
||||
|
||||
|
||||
class TestAccountAssociatedDataFactory:
|
||||
@@ -149,7 +148,6 @@ class TestAccountService:
|
||||
# Setup basic session methods
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.query = MagicMock()
|
||||
|
||||
yield mock_db
|
||||
|
||||
@@ -1572,15 +1570,9 @@ class TestRegisterService:
|
||||
account_id="existing-user-456", email="existing@example.com", status="active"
|
||||
)
|
||||
|
||||
# Mock database queries
|
||||
query_results = {
|
||||
(
|
||||
"TenantAccountJoin",
|
||||
"tenant_id",
|
||||
"tenant-456",
|
||||
): TestAccountAssociatedDataFactory.create_tenant_join_mock(),
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies[
|
||||
"db"
|
||||
].session.scalar.return_value = TestAccountAssociatedDataFactory.create_tenant_join_mock()
|
||||
|
||||
# Mock TenantService methods
|
||||
with (
|
||||
|
||||
@@ -714,7 +714,6 @@ class TestSegmentServiceMutations:
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
|
||||
):
|
||||
segments_query = MagicMock()
|
||||
# execute().all() for segments_info (multi-column)
|
||||
execute_result = MagicMock()
|
||||
execute_result.all.return_value = [
|
||||
|
||||
@@ -36,9 +36,7 @@ class TestDatasourceProviderService:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""
|
||||
Robust, chainable query mock.
|
||||
q returns itself for .filter_by(), .order_by(), .where() so any
|
||||
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
|
||||
Mock session with scalar/scalars defaults for current SQLAlchemy access paths.
|
||||
"""
|
||||
with (
|
||||
patch("services.datasource_provider_service.Session") as mock_cls,
|
||||
@@ -46,20 +44,6 @@ class TestDatasourceProviderService:
|
||||
):
|
||||
sess = MagicMock(spec=Session)
|
||||
|
||||
q = MagicMock()
|
||||
sess.query.return_value = q
|
||||
|
||||
# Self-returning chain — any method called on q returns q
|
||||
q.filter_by.return_value = q
|
||||
q.order_by.return_value = q
|
||||
q.where.return_value = q
|
||||
|
||||
# Default terminal values (tests override per-case)
|
||||
q.first.return_value = None
|
||||
q.all.return_value = []
|
||||
q.count.return_value = 0
|
||||
q.delete.return_value = 1
|
||||
|
||||
# Default values for select()-style calls (tests override per-case)
|
||||
sess.scalar.return_value = None
|
||||
sess.scalars.return_value.all.return_value = []
|
||||
|
||||
@@ -17,23 +17,6 @@ from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
@@ -1649,8 +1649,6 @@ class TestWorkflowServiceCredentialValidation:
|
||||
"""Missing BuiltinToolProvider → plugin requires no credentials → no error."""
|
||||
# Arrange
|
||||
with patch("services.workflow_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert (should NOT raise)
|
||||
service._check_default_tool_credential("tenant-1", "some-provider")
|
||||
|
||||
@@ -1662,10 +1660,6 @@ class TestWorkflowServiceCredentialValidation:
|
||||
patch("services.workflow_service.db") as mock_db,
|
||||
patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")),
|
||||
):
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_provider
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Failed to validate default credential"):
|
||||
service._check_default_tool_credential("tenant-1", "some-provider")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -89,9 +89,6 @@ def mock_db_session():
|
||||
session = MagicMock()
|
||||
session._shared_data = {"dataset": None, "documents": []}
|
||||
|
||||
# Keep a pointer so repeated Document.first() calls iterate across provided docs
|
||||
session._doc_first_idx = 0
|
||||
|
||||
def _get_entity(stmt) -> type | None:
|
||||
"""Extract the mapped entity class from a SQLAlchemy select statement."""
|
||||
try:
|
||||
@@ -1591,18 +1588,7 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")]
|
||||
phase1_document_query = MagicMock()
|
||||
phase1_document_query.where.return_value = phase1_document_query
|
||||
phase1_document_query.all.return_value = phase1_docs
|
||||
|
||||
summary_document_query = MagicMock()
|
||||
summary_document_query.where.return_value = summary_document_query
|
||||
summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
@@ -1657,18 +1643,6 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_query = MagicMock()
|
||||
phase1_query.where.return_value = phase1_query
|
||||
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value = summary_query
|
||||
summary_query.all.return_value = [doc_eligible]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
|
||||
Reference in New Issue
Block a user