diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index 286dda419c..ac09060e9d 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -225,8 +225,10 @@ class TestSpanBuilder: span = builder.build_span(span_data) assert isinstance(span, ReadableSpan) assert span.name == "test-span" + assert span.context is not None assert span.context.trace_id == 123 assert span.context.span_id == 456 + assert span.parent is not None assert span.parent.span_id == 789 assert span.resource == resource assert span.attributes == {"attr1": "val1"} diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 38d33dd21b..a6808fec0a 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -64,12 +64,13 @@ class TestSpanData: def test_span_data_missing_required_fields(self): with pytest.raises(ValidationError): - SpanData( - trace_id=123, - # span_id missing - name="test_span", - start_time=1000, - end_time=2000, + SpanData.model_validate( + { + "trace_id": 123, + "name": "test_span", + "start_time": 1000, + "end_time": 2000, + } ) def test_span_data_arbitrary_types_allowed(self): diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index c1b11c9186..fa00829653 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -2,12 +2,14 @@ from __future__ import annotations from datetime import UTC, datetime from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest from dify_trace_aliyun.aliyun_trace import AliyunDataTrace from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -44,7 +46,7 @@ class RecordingTraceClient: self.endpoint = endpoint self.added_spans: list[object] = [] - def add_span(self, span) -> None: + def add_span(self, span: object) -> None: self.added_spans.append(span) def api_check(self) -> bool: @@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link: trace_id=trace_id, span_id=span_id, is_remote=False, - trace_flags=TraceFlags.SAMPLED, + trace_flags=TraceFlags(TraceFlags.SAMPLED), ) return Link(context) +def _make_trace_metadata( + trace_id: int = 1, + workflow_span_id: int = 2, + session_id: str = "s", + user_id: str = "u", + links: list[Link] | None = None, +) -> TraceMetadata: + return TraceMetadata( + trace_id=trace_id, + workflow_span_id=workflow_span_id, + session_id=session_id, + user_id=user_id, + links=[] if links is None else links, + ) + + +def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient: + return cast(RecordingTraceClient, trace_instance.trace_client) + + +def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]: + return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans) + + def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: defaults = { "workflow_id": "workflow-id", @@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT trace_instance.workflow_trace(trace_info) add_workflow_span.assert_called_once() - passed_trace_metadata = add_workflow_span.call_args.args[1] + passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1]) assert passed_trace_metadata.trace_id == 111 assert passed_trace_metadata.workflow_span_id == 222 assert passed_trace_metadata.session_id == "c" assert passed_trace_metadata.user_id == "u" assert passed_trace_metadata.links == [] - assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] + assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"] def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_message_trace_info(message_data=None) trace_instance.message_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT ) trace_instance.message_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span, llm_span = trace_instance.trace_client.added_spans + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span, llm_span = spans assert message_span.name == "message" assert message_span.trace_id == 10 @@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_dataset_retrieval_trace_info(message_data=None) trace_instance.dataset_retrieval_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "dataset_retrieval" assert span.attributes[RETRIEVAL_QUERY] == "query" assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' @@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_tool_trace_info(message_data=None) trace_instance.tool_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p ) ) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "my-tool" assert span.status == status assert span.attributes[TOOL_NAME] == "my-tool" @@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) @@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) @@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) @@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) @@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) node_execution.node_type = BuiltinNodeTypes.CODE @@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "title" @@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()]) + trace_metadata = _make_trace_metadata(links=[_make_link()]) node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "my-tool" @@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] ) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "retrieval" @@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "llm" @@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() # CASE 1: With message_id trace_info = _make_workflow_trace_info( @@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. ) trace_instance.add_workflow_span(trace_info, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span = trace_instance.trace_client.added_spans[0] - workflow_span = trace_instance.trace_client.added_spans[1] + client = _recording_trace_client(trace_instance) + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span = spans[0] + workflow_span = spans[1] assert message_span.name == "message" assert message_span.span_kind == SpanKind.SERVER @@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. assert workflow_span.span_kind == SpanKind.INTERNAL assert workflow_span.parent_span_id == 20 - trace_instance.trace_client.added_spans.clear() + client.added_spans.clear() # CASE 2: Without message_id trace_info_no_msg = _make_workflow_trace_info(message_id=None) trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "workflow" assert span.span_kind == SpanKind.SERVER assert span.parent_span_id is None @@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) trace_instance.suggested_question_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "suggested_question" assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index a9e7b80c2a..1b97746dea 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,4 +1,6 @@ import json +from collections.abc import Mapping +from typing import Any, cast from unittest.mock import MagicMock from dify_trace_aliyun.entities.semconv import ( @@ -170,7 +172,7 @@ def test_create_common_span_attributes(): def test_format_retrieval_documents(): # Not a list - assert format_retrieval_documents("not a list") == [] + assert format_retrieval_documents(cast(list[object], "not a list")) == [] # Valid list docs = [ @@ -211,7 +213,7 @@ def test_format_retrieval_documents(): def test_format_input_messages(): # Not a dict - assert format_input_messages(None) == serialize_json_data([]) + assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No prompts assert format_input_messages({}) == serialize_json_data([]) @@ -244,7 +246,7 @@ def test_format_input_messages(): def test_format_output_messages(): # Not a dict - assert format_output_messages(None) == serialize_json_data([]) + assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No text assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([]) diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py index 1b24ee7421..8068ee1328 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -25,13 +25,13 @@ class TestAliyunConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - AliyunConfig() + AliyunConfig.model_validate({}) with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") + AliyunConfig.model_validate({"license_key": "test_license"}) with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"}) def test_app_name_validation_empty(self): """Test app_name validation with empty value""" diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index b0691a87ea..e9ecc2e083 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -1,4 +1,5 @@ from datetime import UTC, datetime, timedelta +from typing import cast from unittest.mock import MagicMock, patch import pytest @@ -129,7 +130,7 @@ def test_set_span_status(): return "SilentErrorRepr" span.reset_mock() - set_span_status(span, SilentError()) + set_span_status(span, cast(Exception | str | None, SilentError())) assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py index 103d888eef..0c3c3fc81e 100644 --- a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -28,13 +28,13 @@ class TestLangfuseConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - LangfuseConfig() + LangfuseConfig.model_validate({}) with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") + LangfuseConfig.model_validate({"public_key": "public"}) with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") + LangfuseConfig.model_validate({"secret_key": "secret"}) def test_host_validation_empty(self): """Test host validation with empty value""" diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py index 0340ffb669..82d69b6180 100644 --- a/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock, patch from dify_trace_langfuse.config import LangfuseConfig @@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime: assert trace._get_completion_start_time(start_time, None) is None assert trace._get_completion_start_time(start_time, -1) is None - assert trace._get_completion_start_time(start_time, "invalid") is None + assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py index 37efaf69cf..bd226c9f1a 100644 --- a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -21,13 +21,13 @@ class TestLangSmithConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - LangSmithConfig() + LangSmithConfig.model_validate({}) with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") + LangSmithConfig.model_validate({"api_key": "key"}) with pytest.raises(ValidationError): - LangSmithConfig(project="project") + LangSmithConfig.model_validate({"project": "project"}) def test_endpoint_validation_https_only(self): """Test endpoint validation only allows HTTPS""" diff --git a/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index 20211456e3..46c9750a5d 100644 --- a/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -599,7 +599,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_instance.message_trace(_make_message_trace_info()) mock_tracing["start"].assert_called_once() @@ -609,7 +608,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info(error="something broke") trace_instance.message_trace(trace_info) @@ -620,7 +618,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None monkeypatch.setenv("FILES_URL", "http://files.test") file_data = SimpleNamespace(url="path/to/file.png") @@ -638,7 +635,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info(file_list=None, message_file_data=None) trace_instance.message_trace(trace_info) @@ -651,7 +647,6 @@ class TestMessageTrace: end_user = MagicMock() end_user.session_id = "session-xyz" - mock_db.session.query.return_value.where.return_value.first.return_value = end_user trace_info = _make_message_trace_info( metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"}, @@ -664,7 +659,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info( metadata={"from_account_id": "acc-1"}, diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py index fba290f5b8..2e0796c291 100644 --- a/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py @@ -12,6 +12,7 @@ from __future__ import annotations import uuid from datetime import datetime +from typing import cast from unittest.mock import MagicMock, patch from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid @@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace: return instance +def _add_trace_mock(instance: OpikDataTrace) -> MagicMock: + return cast(MagicMock, instance.add_trace) + + +def _add_span_mock(instance: OpikDataTrace) -> MagicMock: + return cast(MagicMock, instance.add_span) + + # --------------------------------------------------------------------------- # _seed_to_uuid4 # --------------------------------------------------------------------------- @@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId: def test_root_span_is_created(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - assert instance.add_span.called + assert _add_span_mock(instance).called def test_root_span_id_matches_expected(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) expected = self._expected_root_span_id(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["id"] == expected def test_root_span_has_no_parent(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["parent_span_id"] is None def test_trace_name_is_workflow_trace(self): @@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId: trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - trace_kwargs = instance.add_trace.call_args_list[0][0][0] + trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0] assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE def test_root_span_name_is_workflow_trace(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE def test_root_span_has_workflow_tag(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert "workflow" in root_span_kwargs["tags"] def test_node_execution_spans_are_parented_to_root(self): @@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId: instance = self._run(trace_info, node_executions=[node_exec]) # call_args_list[0] = root span, [1] = node execution span - assert instance.add_span.call_count == 2 - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + add_span = _add_span_mock(instance) + assert add_span.call_count == 2 + node_span_kwargs = add_span.call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] == expected_root_span_id def test_node_span_not_parented_to_workflow_app_log_id(self): @@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId: instance = self._run(trace_info, node_executions=[node_exec]) old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id) - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] != old_parent_id def test_root_span_id_differs_from_trace_id(self): @@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId: trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) instance = self._run(trace_info) - trace_kwargs = instance.add_trace.call_args_list[0][0][0] + trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0] assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE def test_root_span_uses_workflow_run_id_directly(self): @@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId: instance = self._run(trace_info) expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["id"] == expected_root_span_id def test_root_span_id_differs_from_no_message_id_case(self): @@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId: instance = self._run(trace_info, node_executions=[node_exec]) - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] == expected_root_span_id diff --git a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py index 1e656e2462..3cd918f408 100644 --- a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py @@ -5,6 +5,7 @@ from __future__ import annotations import sys import types from types import SimpleNamespace +from typing import Any, TypedDict, cast from unittest.mock import MagicMock import pytest @@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version from dify_trace_tencent.entities.tencent_trace_entity import SpanData from opentelemetry.sdk.trace import Event -from opentelemetry.trace import Status, StatusCode +from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags metric_reader_instances: list[DummyMetricReader] = [] meter_provider_instances: list[DummyMeterProvider] = [] @@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality: self.kwargs = kwargs +class PatchedCoreComponents(TypedDict): + span_exporter: MagicMock + span_processor: MagicMock + tracer: MagicMock + span: MagicMock + tracer_provider: MagicMock + logger: MagicMock + trace_api: Any + + def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None: """Drop fake metric modules into sys.modules so the client imports resolve.""" @@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.fixture(autouse=True) -def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: +def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents: span_exporter = MagicMock(name="span_exporter") monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter)) @@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: } +def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext: + return SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + + def _build_client() -> TencentTraceClient: return TencentTraceClient( service_name="service", @@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st def test_resolve_grpc_target_handles_errors() -> None: - assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317) + assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317) @pytest.mark.parametrize( @@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None: client.record_trace_duration(0.5) -def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None: +def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() client.hist_llm_duration = MagicMock(name="hist_llm_duration") client.hist_llm_duration.record.side_effect = RuntimeError("boom") @@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, logger.debug.assert_called() -def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span = patch_core_components["span"] - span.get_span_context.return_value = "ctx" + ctx = _make_span_context(span_id=2) + span.get_span_context.return_value = ctx data = SpanData( trace_id=1, @@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, span.add_event.assert_called_once() span.set_status.assert_called_once() span.end.assert_called_once_with(end_time=20) - assert client.span_contexts[2] == "ctx" + assert client.span_contexts[2] == ctx -def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() - client.span_contexts[10] = "existing" + existing_context = _make_span_context(span_id=10) + client.span_contexts[10] = existing_context span = patch_core_components["span"] - span.get_span_context.return_value = "child" + span.get_span_context.return_value = _make_span_context(span_id=11) data = SpanData( trace_id=1, @@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[ client._create_and_export_span(data) trace_api = patch_core_components["trace_api"] - trace_api.NonRecordingSpan.assert_called_once_with("existing") + trace_api.NonRecordingSpan.assert_called_once_with(existing_context) trace_api.set_span_in_context.assert_called_once() -def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span = patch_core_components["span"] - span.get_span_context.return_value = "ctx" + span.get_span_context.return_value = _make_span_context(span_id=2) client.tracer.start_span.side_effect = RuntimeError("boom") client._create_and_export_span( @@ -385,7 +407,7 @@ def test_get_project_url() -> None: assert client.get_project_url() == "https://console.cloud.tencent.com/apm" -def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None: +def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span_processor = patch_core_components["span_processor"] tracer_provider = patch_core_components["tracer_provider"] @@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object metric_reader.shutdown.assert_called_once() -def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None: +def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() meter_provider = meter_provider_instances[-1] meter_provider.shutdown.side_effect = RuntimeError("boom") + assert client.metric_reader is not None client.metric_reader.shutdown.side_effect = RuntimeError("boom") client.shutdown() @@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p assert client.metric_reader is None -def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None: +def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None: client = _build_client() monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom"))) @@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com logger.exception.assert_called_once() -def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span = patch_core_components["span"] - span.get_span_context.return_value = "ctx" + span.get_span_context.return_value = _make_span_context(span_id=2) data = SpanData.model_construct( trace_id=1, @@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None: hist_mock = MagicMock(name="hist_llm_duration") client.hist_llm_duration = hist_mock - client.record_llm_duration(0.3, {"foo": object(), "bar": 2}) + client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2})) _, attrs = hist_mock.record.call_args.args assert isinstance(attrs["foo"], str) assert attrs["bar"] == 2 @@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None: hist_mock = MagicMock(name="hist_trace_duration") client.hist_trace_duration = hist_mock - client.record_trace_duration(1.0, {"meta": object(), "ok": True}) + client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True})) _, attrs = hist_mock.record.call_args.args assert isinstance(attrs["meta"], str) assert attrs["ok"] is True @@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None: ], ) def test_record_methods_handle_exceptions( - method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object] + method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents ) -> None: client = _build_client() hist_mock = MagicMock(name=attr_name) @@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions( def test_metrics_initializes_grpc_metric_exporter() -> None: client = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter) + assert isinstance(exporter, DummyGrpcMetricExporter) assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 - assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317" - assert metric_reader.exporter.kwargs["insecure"] is False - assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + assert exporter.kwargs["endpoint"] == "trace.example.com:4317" + assert exporter.kwargs["insecure"] is False + assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token" def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") client = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyHttpMetricExporter, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) + assert isinstance(exporter, DummyHttpMetricExporter) assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 - assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint - assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + assert exporter.kwargs["endpoint"] == client.endpoint + assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token" def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") client = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyJsonMetricExporter, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyJsonMetricExporter) + assert isinstance(exporter, DummyJsonMetricExporter) assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 - assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint - assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" - assert "preferred_temporality" in metric_reader.exporter.kwargs + assert exporter.kwargs["endpoint"] == client.endpoint + assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token" + assert "preferred_temporality" in exporter.kwargs def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None: @@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality) _ = _build_client() metric_reader = metric_reader_instances[-1] + exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter) - assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality) - assert "preferred_temporality" not in metric_reader.exporter.kwargs + assert isinstance(exporter, DummyJsonMetricExporterNoTemporality) + assert "preferred_temporality" not in exporter.kwargs def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py index eeb1fe1d87..377c768198 100644 --- a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -31,13 +31,13 @@ class TestWeaveConfig: def test_missing_required_fields(self): """Test that required fields are enforced""" with pytest.raises(ValidationError): - WeaveConfig() + WeaveConfig.model_validate({}) with pytest.raises(ValidationError): - WeaveConfig(api_key="key") + WeaveConfig.model_validate({"api_key": "key"}) with pytest.raises(ValidationError): - WeaveConfig(project="project") + WeaveConfig.model_validate({"project": "project"}) def test_endpoint_validation_https_only(self): """Test endpoint validation only allows HTTPS""" diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index d10e5ed13c..3b5e822b90 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -171,35 +171,13 @@ class TestChatMessageApiPermissions: parent_message_id=None, ) - class MockQuery: - def __init__(self, model): - self.model = model - - def where(self, *args, **kwargs): - return self - - def first(self): - if getattr(self.model, "__name__", "") == "Conversation": - return mock_conversation - return None - - def order_by(self, *args, **kwargs): - return self - - def limit(self, *_): - return self - - def all(self): - if getattr(self.model, "__name__", "") == "Message": - return [mock_message] - return [] - mock_session = mock.Mock() - mock_session.query.side_effect = MockQuery - mock_session.scalar.return_value = False + mock_session.scalar.return_value = mock_conversation + mock_session.scalars.return_value.all.return_value = [mock_message] monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session)) monkeypatch.setattr(message_api, "current_user", mock_account) + monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock()) class DummyPagination: def __init__(self, data, limit, has_more): diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index f14b2c0ae5..635cfee2da 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -24,7 +24,6 @@ def _patch_wraps(): patch("controllers.console.wraps.dify_config", dify_settings), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): - mock_db.session.query.return_value.first.return_value = MagicMock() yield diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index d82933ccb9..3dcd6586e2 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -13,6 +13,12 @@ from models.model import App, Conversation, Message from services.feedback_service import FeedbackService +def _execute_result(rows): + result = mock.Mock() + result.all.return_value = rows + return result + + class TestFeedbackService: """Test FeedbackService methods.""" @@ -81,25 +87,17 @@ class TestFeedbackService: def test_export_feedbacks_csv_format(self, mock_db_session, sample_data): """Test exporting feedback data in CSV format.""" - - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["user_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["user_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["user_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["user_feedback"].from_account, + ) + ] + ) # Test CSV export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -120,25 +118,17 @@ class TestFeedbackService: def test_export_feedbacks_json_format(self, mock_db_session, sample_data): """Test exporting feedback data in JSON format.""" - - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["admin_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["admin_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["admin_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["admin_feedback"].from_account, + ) + ] + ) # Test JSON export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -157,25 +147,17 @@ class TestFeedbackService: def test_export_feedbacks_with_filters(self, mock_db_session, sample_data): """Test exporting feedback with various filters.""" - - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["admin_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["admin_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["admin_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["admin_feedback"].from_account, + ) + ] + ) # Test with filters result = FeedbackService.export_feedbacks( @@ -193,17 +175,7 @@ class TestFeedbackService: def test_export_feedbacks_no_data(self, mock_db_session, sample_data): """Test exporting feedback when no data exists.""" - - # Setup mock query result with no data - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result([]) result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -251,24 +223,17 @@ class TestFeedbackService: created_at=datetime(2024, 1, 1, 10, 0, 0), ) - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["user_feedback"], - long_message, - sample_data["conversation"], - sample_data["app"], - sample_data["user_feedback"].from_account, - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["user_feedback"], + long_message, + sample_data["conversation"], + sample_data["app"], + sample_data["user_feedback"].from_account, + ) + ] + ) # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -309,24 +274,17 @@ class TestFeedbackService: created_at=datetime(2024, 1, 1, 10, 0, 0), ) - # Setup mock query result - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - chinese_feedback, - chinese_message, - sample_data["conversation"], - sample_data["app"], - None, # No account for user feedback - ) - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + chinese_feedback, + chinese_message, + sample_data["conversation"], + sample_data["app"], + None, + ) + ] + ) # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -339,32 +297,24 @@ class TestFeedbackService: def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data): """Test that rating emojis are properly formatted in export.""" - - # Setup mock query result with both like and dislike feedback - mock_query = mock.Mock() - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [ - ( - sample_data["user_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["user_feedback"].from_account, - ), - ( - sample_data["admin_feedback"], - sample_data["message"], - sample_data["conversation"], - sample_data["app"], - sample_data["admin_feedback"].from_account, - ), - ] - - mock_db_session.execute.return_value = mock_query + mock_db_session.execute.return_value = _execute_result( + [ + ( + sample_data["user_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["user_feedback"].from_account, + ), + ( + sample_data["admin_feedback"], + sample_data["message"], + sample_data["conversation"], + sample_data["app"], + sample_data["admin_feedback"].from_account, + ), + ] + ) # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 55873b06a8..7174530e97 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine): configure_session_factory(_unit_test_engine, expire_on_commit=False) -def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): +def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner): """ - Helper to set up the mock DB execute chain for tenant/account authentication. + Helper to stub the tenant-owner execute result for service API app authentication. - This configures the mock to return (tenant, account) for the - db.session.execute(select(...).join().join().where()).one_or_none() - query used by validate_app_token decorator. + The validate_app_token decorator currently resolves the active tenant owner + via db.session.execute(select(Tenant, Account)...).one_or_none(). Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return - mock_account: Mock account object to return + mock_owner: Mock owner object to return from the execute result """ - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner) -def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): +def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join): """ - Helper to set up the mock DB execute chain for dataset tenant authentication. + Helper to stub the tenant-owner execute result for dataset token authentication. - This configures the mock to return (tenant, tenant_account) for the - db.session.execute(select(...).where().where().where().where()).one_or_none() - query used by validate_dataset_token decorator. + The validate_dataset_token decorator currently resolves the owner mapping via + db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and + then loads the Account separately via db.session.get(...). Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return - mock_ta: Mock tenant account object to return + mock_tenant_account_join: Mock tenant-account join object to return """ - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join) diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 9f1ff9b40f..bfa4048191 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") @@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") @@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation: csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with ( patch("services.annotation_service.current_account_with_tenant") as mock_auth, patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), @@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index cb4fe40944..17bee94c52 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -43,7 +43,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = True # Act @@ -76,7 +75,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists # Act with self.app.test_request_context( @@ -109,7 +107,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = False # Act @@ -135,7 +132,6 @@ class TestAuthenticationSecurity: def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db): """Test that reset password returns success with token for existing accounts.""" # Mock the setup check - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists # Test with existing account mock_get_user.return_value = MagicMock(email="existing@example.com") diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index 9929a71120..b7bc73da5f 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi: - IP rate limiting is checked """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "email_token_123" @@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is allowed by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = True @@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is blocked by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = False @@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi: - Prevents spam and abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.side_effect = AccountRegisterError("Account frozen") @@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi: - Defaults to en-US when not specified """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "token" @@ -286,7 +280,6 @@ class TestEmailCodeLoginApi: - User is logged in with token pair """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -335,7 +328,6 @@ class TestEmailCodeLoginApi: - User is logged in after account creation """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"} mock_get_user.return_value = None mock_create_account.return_value = mock_account @@ -369,7 +361,6 @@ class TestEmailCodeLoginApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -392,7 +383,6 @@ class TestEmailCodeLoginApi: - InvalidEmailError is raised when email doesn't match token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} # Act & Assert @@ -415,7 +405,6 @@ class TestEmailCodeLoginApi: - EmailCodeError is raised for wrong verification code """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} # Act & Assert @@ -453,7 +442,6 @@ class TestEmailCodeLoginApi: - User is added as owner of new workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -496,7 +484,6 @@ class TestEmailCodeLoginApi: - WorkspacesLimitExceeded is raised when limit reached """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -538,7 +525,6 @@ class TestEmailCodeLoginApi: - NotAllowedCreateWorkspace is raised when creation disabled """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 0cf97da878..d089be8905 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -110,7 +110,6 @@ class TestLoginApi: - Rate limit is reset after successful login """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -162,7 +161,6 @@ class TestLoginApi: - Authentication proceeds with invitation token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "test@example.com"}} mock_authenticate.return_value = mock_account @@ -199,7 +197,6 @@ class TestLoginApi: - EmailPasswordLoginLimitError is raised when limit exceeded """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True mock_get_invitation.return_value = None @@ -228,7 +225,6 @@ class TestLoginApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_frozen.return_value = True # Act & Assert @@ -268,7 +264,6 @@ class TestLoginApi: - Generic error message prevents user enumeration """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountPasswordError("Invalid password") @@ -305,7 +300,6 @@ class TestLoginApi: - Login is prevented even with valid credentials """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountLoginError("Account is banned") @@ -351,7 +345,6 @@ class TestLoginApi: - User cannot login without an assigned workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -383,7 +376,6 @@ class TestLoginApi: - Security check prevents invitation token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}} @@ -425,7 +417,6 @@ class TestLoginApi: mock_token_pair, ): """Test that login retries with lowercase email when uppercase lookup fails.""" - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account] @@ -459,7 +450,6 @@ class TestLoginApi: mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} mock_get_account.side_effect = Unauthorized("Account is banned.") @@ -513,7 +503,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_current_account.return_value = (mock_account, MagicMock()) # Act @@ -539,7 +528,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() # Create a mock anonymous user that will pass isinstance check anonymous_user = MagicMock() mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {}) diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index c80758c857..810f1b94fc 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -46,7 +46,6 @@ class TestPartnerTenants: patch("libs.login.dify_config.LOGIN_DISABLED", False), patch("libs.login.check_csrf_token") as mock_csrf, ): - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_csrf.return_value = None yield {"db": mock_db, "csrf": mock_csrf} diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 26ff264f18..0b1a32581a 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -24,10 +24,6 @@ def app(): return app -def _mock_wraps_db(mock_db): - mock_db.session.query.return_value.first.return_value = MagicMock() - - def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account: tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id") account = Account(name=account_id, email=email) @@ -64,7 +60,6 @@ class TestChangeEmailSend: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) @@ -117,7 +112,6 @@ class TestChangeEmailSend: """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) @@ -163,7 +157,6 @@ class TestChangeEmailValidity: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) @@ -223,7 +216,6 @@ class TestChangeEmailValidity: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False @@ -280,7 +272,6 @@ class TestChangeEmailValidity: """A token whose phase marker is a string but not a known transition must be rejected.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False @@ -330,7 +321,6 @@ class TestChangeEmailValidity: """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False @@ -378,7 +368,6 @@ class TestChangeEmailReset: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) @@ -434,7 +423,6 @@ class TestChangeEmailReset: """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) @@ -488,7 +476,6 @@ class TestChangeEmailReset: """A verified token for address A must not be replayed to change to address B.""" from controllers.console.auth.error import InvalidTokenError - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) @@ -561,7 +548,6 @@ class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") def test_should_normalize_feedback_email(self, mock_update, mock_db, app): - _mock_wraps_db(mock_db) with app.test_request_context( "/account/delete/feedback", method="POST", @@ -578,7 +564,6 @@ class TestCheckEmailUnique: @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): - _mock_wraps_db(mock_db) mock_is_freeze.return_value = False mock_check_unique.return_value = True diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 239fec8430..811bf5b1e7 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from flask import Flask, g @@ -16,10 +16,6 @@ def app(): return flask_app -def _mock_wraps_db(mock_db): - mock_db.session.query.return_value.first.return_value = MagicMock() - - def _build_feature_flags(): placeholder_quota = SimpleNamespace(limit=0, size=0) workspace_members = SimpleNamespace(is_available=lambda count: True) @@ -49,7 +45,6 @@ class TestMemberInviteEmailApi: mock_get_features, app, ): - _mock_wraps_db(mock_db) mock_get_features.return_value = _build_feature_flags() mock_invite_member.return_value = "token-abc" diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index f6e096a97b..aa4973851a 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -310,7 +310,6 @@ class TestSystemSetup: def test_should_allow_when_setup_complete(self, mock_db): """Test that requests are allowed when setup is complete""" # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists @setup_required def admin_view(): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 44feacf2ad..1422f29849 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None @contextmanager def _mock_db(): - mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True) with patch("extensions.ext_database.db.session", mock_session): yield diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f48ace427d..f5d93b5ac3 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter from controllers.service_api.app.error import AppUnavailableError from models.account import TenantStatus from models.model import App, AppMode -from tests.unit_tests.conftest import setup_mock_tenant_account_query +from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result class TestAppParameterApi: @@ -74,7 +74,7 @@ class TestAppParameterApi: # Mock tenant owner info for login mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -120,7 +120,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -161,7 +161,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -200,7 +200,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -263,7 +263,7 @@ class TestAppMetaApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -331,7 +331,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -388,7 +388,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -434,7 +434,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -486,7 +486,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index eddba5a517..8c89812cb4 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -15,7 +15,10 @@ from flask import Flask from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser -from tests.unit_tests.conftest import setup_mock_tenant_account_query +from tests.unit_tests.conftest import ( + setup_mock_dataset_owner_execute_result, + setup_mock_tenant_owner_execute_result, +) @pytest.fixture @@ -123,9 +126,7 @@ class AuthenticationMocker: mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: - mock_ta = Mock() - mock_ta.account_id = mock_account.id - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) @staticmethod def setup_dataset_auth(mock_db, mock_tenant, mock_account): @@ -133,8 +134,7 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - + setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta) mock_db.session.get.return_value = mock_account diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 288659b192..1b391e67ec 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -701,8 +701,8 @@ class TestDocumentApiDelete: ``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which internally calls ``validate_and_get_api_token``. To bypass the decorator we call the original function via ``__wrapped__`` (preserved by - ``functools.wraps``). ``delete`` queries the dataset via - ``db.session.query(Dataset)`` directly, so we patch ``db`` at the + ``functools.wraps``). ``delete`` loads the dataset via + ``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the controller module. """ diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index a2008e024b..6dfbdcf98e 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan from models.account import TenantStatus from models.model import ApiToken from tests.unit_tests.conftest import ( - setup_mock_dataset_tenant_query, - setup_mock_tenant_account_query, + setup_mock_dataset_owner_execute_result, + setup_mock_tenant_owner_execute_result, ) @@ -141,14 +141,11 @@ class TestValidateAppToken: mock_account = Mock() mock_account.id = str(uuid.uuid4()) - mock_ta = Mock() - mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant via session.get() mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query (execute(select(...)).one_or_none()) - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) + # Mock the tenant owner execute result (execute(select(...)).one_or_none()) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) @validate_app_token def protected_view(app_model): @@ -471,7 +468,7 @@ class TestValidateDatasetToken: mock_account.current_tenant = mock_tenant # Mock the tenant account join query (execute(select(...)).one_or_none()) - setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) + setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta) # Mock the account lookup via session.get() mock_db.session.get.return_value = mock_account diff --git a/api/tests/unit_tests/controllers/web/conftest.py b/api/tests/unit_tests/controllers/web/conftest.py index 274d78c9cf..b7f3244c6c 100644 --- a/api/tests/unit_tests/controllers/web/conftest.py +++ b/api/tests/unit_tests/controllers/web/conftest.py @@ -22,18 +22,16 @@ class FakeSession: def __init__(self, mapping: dict[str, Any] | None = None): self._mapping: dict[str, Any] = mapping or {} - self._model_name: str | None = None - def query(self, model: type) -> FakeSession: - self._model_name = model.__name__ - return self + def get(self, model: type, _ident: object) -> Any: + return self._mapping.get(model.__name__) - def where(self, *_args: object, **_kwargs: object) -> FakeSession: - return self - - def first(self) -> Any: - assert self._model_name is not None - return self._mapping.get(self._model_name) + def scalar(self, stmt: Any) -> Any: + try: + model = stmt.column_descriptions[0]["entity"] + except (AttributeError, IndexError, KeyError, TypeError): + return None + return self._mapping.get(model.__name__) class FakeDB: diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index a1dbc80b20..5f2dc19aab 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -36,18 +36,6 @@ class _FakeSession: def __init__(self, mapping: dict[str, Any]): self._mapping = mapping - self._model_name: str | None = None - - def query(self, model): - self._model_name = model.__name__ - return self - - def where(self, *args, **kwargs): - return self - - def first(self): - assert self._model_name is not None - return self._mapping.get(self._model_name) def get(self, model, ident): return self._mapping.get(model.__name__) diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index a01587d64a..13b953c04d 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -34,7 +34,6 @@ def _patch_wraps(): patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), patch("controllers.web.login.dify_config", web_dify), ): - mock_db.session.query.return_value.first.return_value = MagicMock() yield diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 45d4b0e321..370f7abb8b 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool @@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock ConversationVariable.from_variable to return mock objects @@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py index 9a2dc38f74..c36edf48fc 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker): workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={}) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = workflow + session.get.return_value = workflow mocker.patch.object(module.db, "session", session) queue_manager = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index 618c8fd76f..603062a51c 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker): app_generate_entity.single_iteration_run = None app_generate_entity.single_loop_run = None - query = MagicMock() - query.where.return_value.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.get.side_effect = [None, None] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( @@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker): app_generate_entity = _build_app_generate_entity() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline session = MagicMock() - session.query.return_value = query_pipeline + session.get.side_effect = [None, pipeline] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( diff --git a/api/tests/unit_tests/core/datasource/test_notion_provider.py b/api/tests/unit_tests/core/datasource/test_notion_provider.py index e4bd7d3bdf..d21b9e471b 100644 --- a/api/tests/unit_tests/core/datasource/test_notion_provider.py +++ b/api/tests/unit_tests/core/datasource/test_notion_provider.py @@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime: "last_edited_time": "2024-11-27T18:00:00.000Z", } mock_request.return_value = mock_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query # Act extractor_page.update_last_edited_time(mock_document_model) @@ -863,9 +860,6 @@ class TestNotionExtractorIntegration: } mock_request.side_effect = [last_edited_response, block_response] - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query # Act documents = extractor.extract() @@ -919,10 +913,6 @@ class TestNotionExtractorIntegration: } mock_post.return_value = database_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query - # Act documents = extractor.extract() diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index f3ef7fccd0..73e081a570 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -40,11 +40,11 @@ class TestObfuscatedToken: class TestEncryptToken: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_successful_encryption(self, mock_encrypt, mock_query): + def test_successful_encryption(self, mock_encrypt, mock_get): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -53,9 +53,9 @@ class TestEncryptToken: mock_encrypt.assert_called_with("test_token", "mock_public_key") @patch("extensions.ext_database.db.session.get") - def test_tenant_not_found(self, mock_query): + def test_tenant_not_found(self, mock_get): """Test error when tenant doesn't exist""" - mock_query.return_value = None + mock_get.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") - def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): + def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get): """Test that encryption and decryption are consistent""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -148,12 +148,12 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_cross_tenant_isolation(self, mock_encrypt, mock_query): + def test_cross_tenant_isolation(self, mock_encrypt, mock_get): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -183,10 +183,10 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_encryption_randomness(self, mock_encrypt, mock_query): + def test_encryption_randomness(self, mock_encrypt, mock_get): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -207,11 +207,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): + def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -221,11 +221,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): + def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -244,11 +244,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): + def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 7d6c8f983f..c4e610d5b0 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -491,7 +491,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}} @@ -517,7 +517,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() # Cause exception in node_type logic workflow.graph_dict = {"graph": {"nodes": []}} @@ -544,7 +544,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} @@ -632,7 +632,7 @@ class TestLLMGenerator: instance.invoke_llm.return_value = mock_response with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}} diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py index 136ac0c72a..1e91c2dd88 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -29,15 +29,6 @@ class _Field: return ("in", self._name, tuple(values)) -class _FakeQuery: - def __init__(self): - self.where_calls: list[tuple] = [] - - def where(self, *conditions): - self.where_calls.append(conditions) - return self - - class _FakeExecuteResult: def __init__(self, segments: list[SimpleNamespace]): self._segments = segments diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 0baf85c314..b0ecad4d0c 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -109,17 +109,6 @@ class _FakeExecuteResult: return _FakeExecuteScalarResult(self._data) -class _FakeSummaryQuery: - def __init__(self, summaries: list) -> None: - self._summaries = summaries - - def filter(self, *args, **kwargs): - return self - - def all(self) -> list: - return self._summaries - - class _FakeScalarsResult: def __init__(self, data: list) -> None: self._data = data diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index dc21d378a2..9de04c80ba 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -372,19 +372,11 @@ def test_vector_delegation_methods(vector_factory_module): def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): - class _Field: - def __eq__(self, value): - return value - - upload_query = MagicMock() - upload_query.where.return_value = upload_query - vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) vector._embeddings = MagicMock() vector._vector_processor = MagicMock() mock_session = SimpleNamespace(get=lambda _model, _id: None) - monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session)) assert vector.search_by_file("file-1") == [] diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 7c4defc180..b4bb343533 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk: mock_dependencies["redis"].get.return_value = None - # Mock database query for segment updates - mock_query = MagicMock() - mock_dependencies["db"].session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.update.return_value = None + # Mock database update for segment status + mock_dependencies["db"].session.execute.return_value = None # Create a proper context manager mock mock_context = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 89830f7517..fd607210f1 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -2417,12 +2417,11 @@ class TestDatasetRetrievalKnowledgeRetrieval: mock_document.data_source_type = "upload_file" mock_document.doc_metadata = {} - mock_session.query.return_value.filter.return_value.all.return_value = [ - mock_dataset_from_db - ] - mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( - [mock_dataset_from_db, mock_document] - ) + mock_datasets = MagicMock() + mock_datasets.all.return_value = [mock_dataset_from_db] + mock_documents = MagicMock() + mock_documents.all.return_value = [mock_document] + mock_session.scalars.side_effect = [mock_datasets, mock_documents] # Act result = dataset_retrieval.knowledge_retrieval(request) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 90feb4cf01..aace419d15 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -451,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval: mock_document.data_source_type = "upload_file" mock_document.doc_metadata = {} - mock_session.query.return_value.filter.return_value.all.return_value = [ - mock_dataset_from_db - ] - mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( - [mock_dataset_from_db, mock_document] - ) + mock_datasets = MagicMock() + mock_datasets.all.return_value = [mock_dataset_from_db] + mock_documents = MagicMock() + mock_documents.all.return_value = [mock_document] + mock_session.scalars.side_effect = [mock_datasets, mock_documents] # Act result = dataset_retrieval.knowledge_retrieval(request) diff --git a/api/tests/unit_tests/services/document_indexing_task_proxy.py b/api/tests/unit_tests/services/document_indexing_task_proxy.py deleted file mode 100644 index ff243b8dc3..0000000000 --- a/api/tests/unit_tests/services/document_indexing_task_proxy.py +++ /dev/null @@ -1,1291 +0,0 @@ -""" -Comprehensive unit tests for DocumentIndexingTaskProxy service. - -This module contains extensive unit tests for the DocumentIndexingTaskProxy class, -which is responsible for routing document indexing tasks to appropriate Celery queues -based on tenant billing configuration and managing tenant-isolated task queues. - -The DocumentIndexingTaskProxy handles: -- Task scheduling and queuing (direct vs tenant-isolated queues) -- Priority vs normal task routing based on billing plans -- Tenant isolation using TenantIsolatedTaskQueue -- Batch indexing operations with multiple document IDs -- Error handling and retry logic through queue management - -This test suite ensures: -- Correct task routing based on billing configuration -- Proper tenant isolation queue management -- Accurate batch operation handling -- Comprehensive error condition coverage -- Edge cases are properly handled - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The DocumentIndexingTaskProxy is a critical component in the document indexing -workflow. It acts as a proxy/router that determines which Celery queue to use -for document indexing tasks based on tenant billing configuration. - -1. Task Queue Routing: - - Direct Queue: Bypasses tenant isolation, used for self-hosted/enterprise - - Tenant Queue: Uses tenant isolation, queues tasks when another task is running - - Default Queue: Normal priority with tenant isolation (SANDBOX plan) - - Priority Queue: High priority with tenant isolation (TEAM/PRO plans) - - Priority Direct Queue: High priority without tenant isolation (billing disabled) - -2. Tenant Isolation: - - Uses TenantIsolatedTaskQueue to ensure only one indexing task runs per tenant - - When a task is running, new tasks are queued in Redis - - When a task completes, it pulls the next task from the queue - - Prevents resource contention and ensures fair task distribution - -3. Billing Configuration: - - SANDBOX plan: Uses default tenant queue (normal priority, tenant isolated) - - TEAM/PRO plans: Uses priority tenant queue (high priority, tenant isolated) - - Billing disabled: Uses priority direct queue (high priority, no isolation) - -4. Batch Operations: - - Supports indexing multiple documents in a single task - - DocumentTask entity serializes task information - - Tasks are queued with all document IDs for batch processing - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. Initialization and Configuration: - - Proxy initialization with various parameters - - TenantIsolatedTaskQueue initialization - - Features property caching - - Edge cases (empty document_ids, single document, large batches) - -2. Task Queue Routing: - - Direct queue routing (bypasses tenant isolation) - - Tenant queue routing with existing task key (pushes to waiting queue) - - Tenant queue routing without task key (sets flag and executes immediately) - - DocumentTask serialization and deserialization - - Task function delay() call with correct parameters - -3. Queue Type Selection: - - Default tenant queue routing (normal_document_indexing_task) - - Priority tenant queue routing (priority_document_indexing_task with isolation) - - Priority direct queue routing (priority_document_indexing_task without isolation) - -4. Dispatch Logic: - - Billing enabled + SANDBOX plan → default tenant queue - - Billing enabled + non-SANDBOX plan (TEAM, PRO, etc.) → priority tenant queue - - Billing disabled (self-hosted/enterprise) → priority direct queue - - All CloudPlan enum values handling - - Edge cases: None plan, empty plan string - -5. Tenant Isolation and Queue Management: - - Task key existence checking (get_task_key) - - Task waiting time setting (set_task_waiting_time) - - Task pushing to queue (push_tasks) - - Queue state transitions (idle → active → idle) - - Multiple concurrent task handling - -6. Batch Operations: - - Single document indexing - - Multiple document batch indexing - - Large batch handling - - Empty batch handling (edge case) - -7. Error Handling and Retry Logic: - - Task function delay() failure handling - - Queue operation failures (Redis errors) - - Feature service failures - - Invalid task data handling - - Retry mechanism through queue pull operations - -8. Integration Points: - - FeatureService integration (billing features, subscription plans) - - TenantIsolatedTaskQueue integration (Redis operations) - - Celery task integration (normal_document_indexing_task, priority_document_indexing_task) - - DocumentTask entity serialization - -================================================================================ -""" - -from unittest.mock import Mock, patch - -import pytest - -from core.entities.document_task import DocumentTask -from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy - -# ============================================================================ -# Test Data Factory -# ============================================================================ - - -class DocumentIndexingTaskProxyTestDataFactory: - """ - Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests. - - This factory provides static methods to create mock objects for: - - FeatureService features with billing configuration - - TenantIsolatedTaskQueue mocks with various states - - DocumentIndexingTaskProxy instances with different configurations - - DocumentTask entities for testing serialization - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: - """ - Create mock features with billing configuration. - - This method creates a mock FeatureService features object with - billing configuration that can be used to test different billing - scenarios in the DocumentIndexingTaskProxy. - - Args: - billing_enabled: Whether billing is enabled for the tenant - plan: The CloudPlan enum value for the subscription plan - - Returns: - Mock object configured as FeatureService features with billing info - """ - features = Mock() - - features.billing = Mock() - - features.billing.enabled = billing_enabled - - features.billing.subscription = Mock() - - features.billing.subscription.plan = plan - - return features - - @staticmethod - def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: - """ - Create mock TenantIsolatedTaskQueue. - - This method creates a mock TenantIsolatedTaskQueue that can simulate - different queue states for testing tenant isolation logic. - - Args: - has_task_key: Whether the queue has an active task key (task running) - - Returns: - Mock object configured as TenantIsolatedTaskQueue - """ - queue = Mock(spec=TenantIsolatedTaskQueue) - - queue.get_task_key.return_value = "task_key" if has_task_key else None - - queue.push_tasks = Mock() - - queue.set_task_waiting_time = Mock() - - queue.delete_task_key = Mock() - - return queue - - @staticmethod - def create_document_task_proxy( - tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None - ) -> DocumentIndexingTaskProxy: - """ - Create DocumentIndexingTaskProxy instance for testing. - - This method creates a DocumentIndexingTaskProxy instance with default - or specified parameters for use in test cases. - - Args: - tenant_id: Tenant identifier for the proxy - dataset_id: Dataset identifier for the proxy - document_ids: List of document IDs to index (defaults to 3 documents) - - Returns: - DocumentIndexingTaskProxy instance configured for testing - """ - if document_ids is None: - document_ids = ["doc-1", "doc-2", "doc-3"] - - return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - @staticmethod - def create_document_task( - tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None - ) -> DocumentTask: - """ - Create DocumentTask entity for testing. - - This method creates a DocumentTask entity that can be used to test - task serialization and deserialization logic. - - Args: - tenant_id: Tenant identifier for the task - dataset_id: Dataset identifier for the task - document_ids: List of document IDs to index (defaults to 3 documents) - - Returns: - DocumentTask entity configured for testing - """ - if document_ids is None: - document_ids = ["doc-1", "doc-2", "doc-3"] - - return DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - - -# ============================================================================ -# Test Classes -# ============================================================================ - - -class TestDocumentIndexingTaskProxy: - """ - Comprehensive unit tests for DocumentIndexingTaskProxy class. - - This test class covers all methods and scenarios of the DocumentIndexingTaskProxy, - including initialization, task routing, queue management, dispatch logic, and - error handling. - """ - - # ======================================================================== - # Initialization Tests - # ======================================================================== - - def test_initialization(self): - """ - Test DocumentIndexingTaskProxy initialization. - - This test verifies that the proxy is correctly initialized with - the provided tenant_id, dataset_id, and document_ids, and that - the TenantIsolatedTaskQueue is properly configured. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = ["doc-1", "doc-2", "doc-3"] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) - - assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id - - assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" - - def test_initialization_with_empty_document_ids(self): - """ - Test initialization with empty document_ids list. - - This test verifies that the proxy can be initialized with an empty - document_ids list, which may occur in edge cases or error scenarios. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = [] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert len(proxy._document_ids) == 0 - - def test_initialization_with_single_document_id(self): - """ - Test initialization with single document_id. - - This test verifies that the proxy can be initialized with a single - document ID, which is a common use case for single document indexing. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = ["doc-1"] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert len(proxy._document_ids) == 1 - - def test_initialization_with_large_batch(self): - """ - Test initialization with large batch of document IDs. - - This test verifies that the proxy can handle large batches of - document IDs, which may occur in bulk indexing scenarios. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = [f"doc-{i}" for i in range(100)] - - # Act - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - - # Assert - assert proxy._tenant_id == tenant_id - - assert proxy._dataset_id == dataset_id - - assert proxy._document_ids == document_ids - - assert len(proxy._document_ids) == 100 - - # ======================================================================== - # Features Property Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_features_property(self, mock_feature_service): - """ - Test cached_property features. - - This test verifies that the features property is correctly cached - and that FeatureService.get_features is called only once, even when - the property is accessed multiple times. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - # Act - features1 = proxy.features - - features2 = proxy.features # Second call should use cached property - - # Assert - assert features1 == mock_features - - assert features2 == mock_features - - assert features1 is features2 # Should be the same instance due to caching - - mock_feature_service.get_features.assert_called_once_with("tenant-123") - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_features_property_with_different_tenants(self, mock_feature_service): - """ - Test features property with different tenant IDs. - - This test verifies that the features property correctly calls - FeatureService.get_features with the correct tenant_id for each - proxy instance. - """ - # Arrange - mock_features1 = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() - - mock_features2 = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() - - mock_feature_service.get_features.side_effect = [mock_features1, mock_features2] - - proxy1 = DocumentIndexingTaskProxy("tenant-1", "dataset-1", ["doc-1"]) - - proxy2 = DocumentIndexingTaskProxy("tenant-2", "dataset-2", ["doc-2"]) - - # Act - features1 = proxy1.features - - features2 = proxy2.features - - # Assert - assert features1 == mock_features1 - - assert features2 == mock_features2 - - mock_feature_service.get_features.assert_any_call("tenant-1") - - mock_feature_service.get_features.assert_any_call("tenant-2") - - # ======================================================================== - # Direct Queue Routing Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue(self, mock_task): - """ - Test _send_to_direct_queue method. - - This test verifies that _send_to_direct_queue correctly calls - task_func.delay() with the correct parameters, bypassing tenant - isolation queue management. - """ - # Arrange - tenant_id = "tenant-direct-queue" - dataset_id = "dataset-direct-queue" - document_ids = ["doc-direct-1", "doc-direct-2"] - proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_direct_queue_with_priority_task(self, mock_task): - """ - Test _send_to_direct_queue with priority task function. - - This test verifies that _send_to_direct_queue works correctly - with priority_document_indexing_task as the task function. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue_with_single_document(self, mock_task): - """ - Test _send_to_direct_queue with single document ID. - - This test verifies that _send_to_direct_queue correctly handles - a single document ID in the document_ids list. - """ - # Arrange - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", ["doc-1"]) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue_with_empty_documents(self, mock_task): - """ - Test _send_to_direct_queue with empty document_ids list. - - This test verifies that _send_to_direct_queue correctly handles - an empty document_ids list, which may occur in edge cases. - """ - # Arrange - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", []) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with(tenant_id="tenant-123", dataset_id="dataset-456", document_ids=[]) - - # ======================================================================== - # Tenant Queue Routing Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): - """ - Test _send_to_tenant_queue when task key exists. - - This test verifies that when a task key exists (indicating another - task is running), the new task is pushed to the waiting queue instead - of being executed immediately. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=True - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() - - pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] - - assert len(pushed_tasks) == 1 - - expected_task_data = { - "tenant_id": "tenant-123", - "dataset_id": "dataset-456", - "document_ids": ["doc-1", "doc-2", "doc-3"], - } - assert pushed_tasks[0] == expected_task_data - - assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] - - mock_task.delay.assert_not_called() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_without_task_key(self, mock_task): - """ - Test _send_to_tenant_queue when no task key exists. - - This test verifies that when no task key exists (indicating no task - is currently running), the task is executed immediately and the - task waiting time flag is set. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_tenant_queue_with_priority_task(self, mock_task): - """ - Test _send_to_tenant_queue with priority task function. - - This test verifies that _send_to_tenant_queue works correctly - with priority_document_indexing_task as the task function. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_document_task_serialization(self, mock_task): - """ - Test DocumentTask serialization in _send_to_tenant_queue. - - This test verifies that DocumentTask entities are correctly - serialized to dictionaries when pushing to the waiting queue. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=True - ) - - mock_task.delay = Mock() - - # Act - proxy._send_to_tenant_queue(mock_task) - - # Assert - pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] - - task_dict = pushed_tasks[0] - - # Verify the task can be deserialized back to DocumentTask - document_task = DocumentTask(**task_dict) - - assert document_task.tenant_id == "tenant-123" - - assert document_task.dataset_id == "dataset-456" - - assert document_task.document_ids == ["doc-1", "doc-2", "doc-3"] - - # ======================================================================== - # Queue Type Selection Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_default_tenant_queue(self, mock_task): - """ - Test _send_to_default_tenant_queue method. - - This test verifies that _send_to_default_tenant_queue correctly - calls _send_to_tenant_queue with normal_document_indexing_task. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_tenant_queue = Mock() - - # Act - proxy._send_to_default_tenant_queue() - - # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_tenant_queue(self, mock_task): - """ - Test _send_to_priority_tenant_queue method. - - This test verifies that _send_to_priority_tenant_queue correctly - calls _send_to_tenant_queue with priority_document_indexing_task. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_tenant_queue = Mock() - - # Act - proxy._send_to_priority_tenant_queue() - - # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_direct_queue(self, mock_task): - """ - Test _send_to_priority_direct_queue method. - - This test verifies that _send_to_priority_direct_queue correctly - calls _send_to_direct_queue with priority_document_indexing_task. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_direct_queue = Mock() - - # Act - proxy._send_to_priority_direct_queue() - - # Assert - proxy._send_to_direct_queue.assert_called_once_with(mock_task) - - # ======================================================================== - # Dispatch Logic Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): - """ - Test _dispatch method when billing is enabled with SANDBOX plan. - - This test verifies that when billing is enabled and the subscription - plan is SANDBOX, the dispatch method routes to the default tenant queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_default_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_default_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service): - """ - Test _dispatch method when billing is enabled with TEAM plan. - - This test verifies that when billing is enabled and the subscription - plan is TEAM, the dispatch method routes to the priority tenant queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.TEAM - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service): - """ - Test _dispatch method when billing is enabled with PROFESSIONAL plan. - - This test verifies that when billing is enabled and the subscription - plan is PROFESSIONAL, the dispatch method routes to the priority tenant queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.PROFESSIONAL - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_with_billing_disabled(self, mock_feature_service): - """ - Test _dispatch method when billing is disabled. - - This test verifies that when billing is disabled (e.g., self-hosted - or enterprise), the dispatch method routes to the priority direct queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_direct_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_direct_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_edge_case_empty_plan(self, mock_feature_service): - """ - Test _dispatch method with empty plan string. - - This test verifies that when billing is enabled but the plan is an - empty string, the dispatch method routes to the priority tenant queue - (treats it as a non-SANDBOX plan). - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="") - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_edge_case_none_plan(self, mock_feature_service): - """ - Test _dispatch method with None plan. - - This test verifies that when billing is enabled but the plan is None, - the dispatch method routes to the priority tenant queue (treats it as - a non-SANDBOX plan). - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy._dispatch() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - # ======================================================================== - # Delay Method Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_delay_method(self, mock_feature_service): - """ - Test delay method integration. - - This test verifies that the delay method correctly calls _dispatch, - which is the public interface for scheduling document indexing tasks. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_default_tenant_queue = Mock() - - # Act - proxy.delay() - - # Assert - proxy._send_to_default_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_delay_method_with_team_plan(self, mock_feature_service): - """ - Test delay method with TEAM plan. - - This test verifies that the delay method correctly routes to the - priority tenant queue when the subscription plan is TEAM. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.TEAM - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_tenant_queue = Mock() - - # Act - proxy.delay() - - # Assert - proxy._send_to_priority_tenant_queue.assert_called_once() - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_delay_method_with_billing_disabled(self, mock_feature_service): - """ - Test delay method with billing disabled. - - This test verifies that the delay method correctly routes to the - priority direct queue when billing is disabled. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._send_to_priority_direct_queue = Mock() - - # Act - proxy.delay() - - # Assert - proxy._send_to_priority_direct_queue.assert_called_once() - - # ======================================================================== - # DocumentTask Entity Tests - # ======================================================================== - - def test_document_task_dataclass(self): - """ - Test DocumentTask dataclass. - - This test verifies that DocumentTask entities can be created and - accessed correctly, which is important for task serialization. - """ - # Arrange - tenant_id = "tenant-123" - - dataset_id = "dataset-456" - - document_ids = ["doc-1", "doc-2"] - - # Act - task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - - # Assert - assert task.tenant_id == tenant_id - - assert task.dataset_id == dataset_id - - assert task.document_ids == document_ids - - def test_document_task_serialization(self): - """ - Test DocumentTask serialization to dictionary. - - This test verifies that DocumentTask entities can be correctly - serialized to dictionaries using asdict() for queue storage. - """ - # Arrange - from dataclasses import asdict - - task = DocumentIndexingTaskProxyTestDataFactory.create_document_task() - - # Act - task_dict = asdict(task) - - # Assert - assert task_dict["tenant_id"] == "tenant-123" - - assert task_dict["dataset_id"] == "dataset-456" - - assert task_dict["document_ids"] == ["doc-1", "doc-2", "doc-3"] - - def test_document_task_deserialization(self): - """ - Test DocumentTask deserialization from dictionary. - - This test verifies that DocumentTask entities can be correctly - deserialized from dictionaries when pulled from the queue. - """ - # Arrange - task_dict = { - "tenant_id": "tenant-123", - "dataset_id": "dataset-456", - "document_ids": ["doc-1", "doc-2", "doc-3"], - } - - # Act - task = DocumentTask(**task_dict) - - # Assert - assert task.tenant_id == "tenant-123" - - assert task.dataset_id == "dataset-456" - - assert task.document_ids == ["doc-1", "doc-2", "doc-3"] - - # ======================================================================== - # Batch Operations Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_batch_operation_with_multiple_documents(self, mock_task): - """ - Test batch operation with multiple documents. - - This test verifies that the proxy correctly handles batch operations - with multiple document IDs in a single task. - """ - # Arrange - document_ids = [f"doc-{i}" for i in range(10)] - - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", document_ids) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_batch_operation_with_large_batch(self, mock_task): - """ - Test batch operation with large batch of documents. - - This test verifies that the proxy correctly handles large batches - of document IDs, which may occur in bulk indexing scenarios. - """ - # Arrange - document_ids = [f"doc-{i}" for i in range(100)] - - proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", document_ids) - - mock_task.delay = Mock() - - # Act - proxy._send_to_direct_queue(mock_task) - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids - ) - - assert len(mock_task.delay.call_args[1]["document_ids"]) == 100 - - # ======================================================================== - # Error Handling Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_direct_queue_task_delay_failure(self, mock_task): - """ - Test _send_to_direct_queue when task.delay() raises an exception. - - This test verifies that exceptions raised by task.delay() are - propagated correctly and not swallowed. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_task.delay.side_effect = Exception("Task delay failed") - - # Act & Assert - with pytest.raises(Exception, match="Task delay failed"): - proxy._send_to_direct_queue(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_push_tasks_failure(self, mock_task): - """ - Test _send_to_tenant_queue when push_tasks raises an exception. - - This test verifies that exceptions raised by push_tasks are - propagated correctly when a task key exists. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(has_task_key=True) - - mock_queue.push_tasks.side_effect = Exception("Push tasks failed") - - proxy._tenant_isolated_task_queue = mock_queue - - # Act & Assert - with pytest.raises(Exception, match="Push tasks failed"): - proxy._send_to_tenant_queue(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task): - """ - Test _send_to_tenant_queue when set_task_waiting_time raises an exception. - - This test verifies that exceptions raised by set_task_waiting_time are - propagated correctly when no task key exists. - """ - # Arrange - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(has_task_key=False) - - mock_queue.set_task_waiting_time.side_effect = Exception("Set waiting time failed") - - proxy._tenant_isolated_task_queue = mock_queue - - # Act & Assert - with pytest.raises(Exception, match="Set waiting time failed"): - proxy._send_to_tenant_queue(mock_task) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - def test_dispatch_feature_service_failure(self, mock_feature_service): - """ - Test _dispatch when FeatureService.get_features raises an exception. - - This test verifies that exceptions raised by FeatureService.get_features - are propagated correctly during dispatch. - """ - # Arrange - mock_feature_service.get_features.side_effect = Exception("Feature service failed") - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - # Act & Assert - with pytest.raises(Exception, match="Feature service failed"): - proxy._dispatch() - - # ======================================================================== - # Integration Tests - # ======================================================================== - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") - def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service): - """ - Test full flow for SANDBOX plan with tenant queue. - - This test verifies the complete flow from delay() call to task - scheduling for a SANDBOX plan tenant, including tenant isolation. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_full_flow_team_plan(self, mock_task, mock_feature_service): - """ - Test full flow for TEAM plan with priority tenant queue. - - This test verifies the complete flow from delay() call to task - scheduling for a TEAM plan tenant, including priority routing. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.TEAM - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=False - ) - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() - - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") - def test_full_flow_billing_disabled(self, mock_task, mock_feature_service): - """ - Test full flow for billing disabled (self-hosted/enterprise). - - This test verifies the complete flow from delay() call to task - scheduling when billing is disabled, using priority direct queue. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - mock_task.delay.assert_called_once_with( - tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] - ) - - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") - def test_full_flow_with_existing_task_key(self, mock_task, mock_feature_service): - """ - Test full flow when task key exists (task queuing). - - This test verifies the complete flow when another task is already - running, ensuring the new task is queued correctly. - """ - # Arrange - mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( - billing_enabled=True, plan=CloudPlan.SANDBOX - ) - - mock_feature_service.get_features.return_value = mock_features - - proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() - - proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( - has_task_key=True - ) - - mock_task.delay = Mock() - - # Act - proxy.delay() - - # Assert - proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() - - pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] - - expected_task_data = { - "tenant_id": "tenant-123", - "dataset_id": "dataset-456", - "document_ids": ["doc-1", "doc-2", "doc-3"], - } - assert pushed_tasks[0] == expected_task_data - - assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] - - mock_task.delay.assert_not_called() diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py deleted file mode 100644 index 83bae370eb..0000000000 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ /dev/null @@ -1,925 +0,0 @@ -""" -Extensive unit tests for ``ExternalDatasetService``. - -This module focuses on the *external dataset service* surface area, which is responsible -for integrating with **external knowledge APIs** and wiring them into Dify datasets. - -The goal of this test suite is twofold: - -- Provide **high‑confidence regression coverage** for all public helpers on - ``ExternalDatasetService``. -- Serve as **executable documentation** for how external API integration is expected - to behave in different scenarios (happy paths, validation failures, and error codes). - -The file intentionally contains **rich comments and generous spacing** in order to make -each scenario easy to scan during reviews. -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock, Mock, patch - -import httpx -import pytest - -from constants import HIDDEN_VALUE -from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings -from services.entities.external_knowledge_entities.external_knowledge_entities import ( - Authorization, - AuthorizationConfig, - ExternalKnowledgeApiSetting, -) -from services.errors.dataset import DatasetNameDuplicateError -from services.external_knowledge_service import ExternalDatasetService - - -class ExternalDatasetTestDataFactory: - """ - Factory helpers for building *lightweight* mocks for external knowledge tests. - - These helpers are intentionally small and explicit: - - - They avoid pulling in unnecessary fixtures. - - They reflect the minimal contract that the service under test cares about. - """ - - @staticmethod - def create_external_api( - api_id: str = "api-123", - tenant_id: str = "tenant-1", - name: str = "Test API", - description: str = "Description", - settings: dict[str, Any] | None = None, - ) -> ExternalKnowledgeApis: - """ - Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields. - - Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to - exercise ``settings_dict`` and other convenience properties if needed. - """ - - instance = ExternalKnowledgeApis( - tenant_id=tenant_id, - name=name, - description=description, - settings=None if settings is None else cast(str, pytest.approx), # type: ignore[assignment] - ) - - # Overwrite generated id for determinism in assertions. - instance.id = api_id - return instance - - @staticmethod - def create_dataset( - dataset_id: str = "ds-1", - tenant_id: str = "tenant-1", - name: str = "External Dataset", - provider: str = "external", - ) -> Dataset: - """ - Build a small ``Dataset`` instance representing an external dataset. - """ - - dataset = Dataset( - tenant_id=tenant_id, - name=name, - description="", - provider=provider, - created_by="user-1", - ) - dataset.id = dataset_id - return dataset - - @staticmethod - def create_external_binding( - tenant_id: str = "tenant-1", - dataset_id: str = "ds-1", - api_id: str = "api-1", - external_knowledge_id: str = "knowledge-1", - ) -> ExternalKnowledgeBindings: - """ - Small helper for a binding between dataset and external knowledge API. - """ - - binding = ExternalKnowledgeBindings( - tenant_id=tenant_id, - dataset_id=dataset_id, - external_knowledge_api_id=api_id, - external_knowledge_id=external_knowledge_id, - created_by="user-1", - ) - return binding - - -# --------------------------------------------------------------------------- -# get_external_knowledge_apis -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceGetExternalKnowledgeApis: - """ - Tests for ``ExternalDatasetService.get_external_knowledge_apis``. - - These tests focus on: - - - Basic pagination wiring via ``db.paginate``. - - Optional search keyword behaviour. - """ - - @pytest.fixture - def mock_db_paginate(self): - """ - Patch ``db.paginate`` so we do not touch the real database layer. - """ - - with ( - patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate, - patch("services.external_knowledge_service.select", autospec=True), - ): - yield mock_paginate - - def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock): - """ - It should return ``items`` and ``total`` coming from the paginate object. - """ - - # Arrange - tenant_id = "tenant-1" - page = 1 - per_page = 20 - - mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)] - mock_pagination = SimpleNamespace(items=mock_items, total=42) - mock_db_paginate.return_value = mock_pagination - - # Act - items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id) - - # Assert - assert items is mock_items - assert total == 42 - - mock_db_paginate.assert_called_once() - call_kwargs = mock_db_paginate.call_args.kwargs - assert call_kwargs["page"] == page - assert call_kwargs["per_page"] == per_page - assert call_kwargs["max_per_page"] == 100 - assert call_kwargs["error_out"] is False - - def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock): - """ - When a search keyword is provided, the query should be adjusted - (we simply assert that paginate is still called and does not explode). - """ - - # Arrange - tenant_id = "tenant-1" - page = 2 - per_page = 10 - search = "foo" - - mock_pagination = SimpleNamespace(items=[], total=0) - mock_db_paginate.return_value = mock_pagination - - # Act - items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search) - - # Assert - assert items == [] - assert total == 0 - mock_db_paginate.assert_called_once() - - -# --------------------------------------------------------------------------- -# validate_api_list -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceValidateApiList: - """ - Lightweight validation tests for ``validate_api_list``. - """ - - def test_validate_api_list_success(self): - """ - A minimal valid configuration (endpoint + api_key) should pass. - """ - - config = {"endpoint": "https://example.com", "api_key": "secret"} - - # Act & Assert – no exception expected - ExternalDatasetService.validate_api_list(config) - - @pytest.mark.parametrize( - ("config", "expected_message"), - [ - ({}, "api list is empty"), - ({"api_key": "k"}, "endpoint is required"), - ({"endpoint": "https://example.com"}, "api_key is required"), - ], - ) - def test_validate_api_list_failures(self, config: dict[str, Any], expected_message: str): - """ - Invalid configs should raise ``ValueError`` with a clear message. - """ - - with pytest.raises(ValueError, match=expected_message): - ExternalDatasetService.validate_api_list(config) - - -# --------------------------------------------------------------------------- -# create_external_knowledge_api & get/update/delete -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceCrudExternalKnowledgeApi: - """ - CRUD tests for external knowledge API templates. - """ - - @pytest.fixture - def mock_db_session(self): - """ - Patch ``db.session`` for all CRUD tests in this class. - """ - - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock): - """ - ``create_external_knowledge_api`` should persist a new record - when settings are present and valid. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - args = { - "name": "API", - "description": "desc", - "settings": {"endpoint": "https://api.example.com", "api_key": "secret"}, - } - - # We do not want to actually call the remote endpoint here, so we patch the validator. - with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check: - result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args) - - assert isinstance(result, ExternalKnowledgeApis) - mock_check.assert_called_once_with(args["settings"]) - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - - def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock): - """ - Missing ``settings`` should result in a ``ValueError``. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - args = {"name": "API", "description": "desc"} - - with pytest.raises(ValueError, match="settings is required"): - ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args) - - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock): - """ - ``get_external_knowledge_api`` should return the first matching record. - """ - - api = Mock(spec=ExternalKnowledgeApis) - mock_db_session.scalar.return_value = api - - result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id") - assert result is api - - def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock): - """ - When the record is absent, a ``ValueError`` is raised. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id") - - def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock): - """ - Updating an API should keep the existing API key when the special hidden - value placeholder is sent from the UI. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - api_id = "api-1" - - existing_api = Mock(spec=ExternalKnowledgeApis) - existing_api.settings_dict = {"api_key": "stored-key"} - existing_api.settings = '{"api_key":"stored-key"}' - mock_db_session.scalar.return_value = existing_api - - args = { - "name": "New Name", - "description": "New Desc", - "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE}, - } - - result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args) - - assert result is existing_api - # The placeholder should be replaced with stored key. - assert args["settings"]["api_key"] == "stored-key" - mock_db_session.commit.assert_called_once() - - def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock): - """ - Updating a non‑existent API template should raise ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.update_external_knowledge_api( - tenant_id="tenant-1", - user_id="user-1", - external_knowledge_api_id="missing-id", - args={"name": "n", "description": "d", "settings": {}}, - ) - - def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock): - """ - ``delete_external_knowledge_api`` should delete and commit when found. - """ - - api = Mock(spec=ExternalKnowledgeApis) - mock_db_session.scalar.return_value = api - - ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1") - - mock_db_session.delete.assert_called_once_with(api) - mock_db_session.commit.assert_called_once() - - def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock): - """ - Deletion of a missing template should raise ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing") - - -# --------------------------------------------------------------------------- -# external_knowledge_api_use_check & binding lookups -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceUsageAndBindings: - """ - Tests for usage checks and dataset binding retrieval. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock): - """ - When there are bindings, ``external_knowledge_api_use_check`` returns True and count. - """ - - mock_db_session.scalar.return_value = 3 - - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") - - assert in_use is True - assert count == 3 - assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0]) - - def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock): - """ - Zero bindings should return ``(False, 0)``. - """ - - mock_db_session.scalar.return_value = 0 - - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") - - assert in_use is False - assert count == 0 - - def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock): - """ - Binding lookup should return the first record when present. - """ - - binding = Mock(spec=ExternalKnowledgeBindings) - mock_db_session.scalar.return_value = binding - - result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1") - assert result is binding - - def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock): - """ - Missing binding should result in a ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="external knowledge binding not found"): - ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1") - - -# --------------------------------------------------------------------------- -# document_create_args_validate -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceDocumentCreateArgsValidate: - """ - Tests for ``document_create_args_validate``. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_document_create_args_validate_success(self, mock_db_session: MagicMock): - """ - All required custom parameters present – validation should pass. - """ - - external_api = Mock(spec=ExternalKnowledgeApis) - external_api.settings = json_settings = ( - '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]' - ) - # Raw string; the service itself calls json.loads on it - mock_db_session.scalar.return_value = external_api - - process_parameter = {"foo": "value", "bar": "optional"} - - # Act & Assert – no exception - ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter) - - assert json_settings in external_api.settings # simple sanity check on our test data - - def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock): - """ - When the referenced API template is missing, a ``ValueError`` is raised. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {}) - - def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock): - """ - Required document process parameters must be supplied. - """ - - external_api = Mock(spec=ExternalKnowledgeApis) - external_api.settings = ( - '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]' - ) - mock_db_session.scalar.return_value = external_api - - process_parameter = {"bar": "present"} # missing "foo" - - with pytest.raises(ValueError, match="foo is required"): - ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter) - - -# --------------------------------------------------------------------------- -# process_external_api -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceProcessExternalApi: - """ - Tests focused on the HTTP request assembly and method mapping behaviour. - """ - - def test_process_external_api_valid_method_post(self): - """ - For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function. - """ - - settings = ExternalKnowledgeApiSetting( - url="https://example.com/path", - request_method="POST", - headers={"X-Test": "1"}, - params={"foo": "bar"}, - ) - - fake_response = httpx.Response(200) - - with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post: - mock_post.return_value = fake_response - - result = ExternalDatasetService.process_external_api(settings, files=None) - - assert result is fake_response - mock_post.assert_called_once() - kwargs = mock_post.call_args.kwargs - assert kwargs["url"] == settings.url - assert kwargs["headers"] == settings.headers - assert kwargs["follow_redirects"] is True - assert "data" in kwargs - - def test_process_external_api_invalid_method_raises(self): - """ - An unsupported HTTP verb should raise ``InvalidHttpMethodError``. - """ - - settings = ExternalKnowledgeApiSetting( - url="https://example.com", - request_method="INVALID", - headers=None, - params={}, - ) - - from graphon.nodes.http_request.exc import InvalidHttpMethodError - - with pytest.raises(InvalidHttpMethodError): - ExternalDatasetService.process_external_api(settings, files=None) - - -# --------------------------------------------------------------------------- -# assembling_headers -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceAssemblingHeaders: - """ - Tests for header assembly based on different authentication flavours. - """ - - def test_assembling_headers_bearer_token(self): - """ - For bearer auth we expect ``Authorization: Bearer `` by default. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="bearer", api_key="secret", header=None), - ) - - headers = ExternalDatasetService.assembling_headers(auth) - - assert headers["Authorization"] == "Bearer secret" - - def test_assembling_headers_basic_token_with_custom_header(self): - """ - For basic auth we honour the configured header name. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"), - ) - - headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"}) - - assert headers["Existing"] == "1" - assert headers["X-Auth"] == "Basic abc123" - - def test_assembling_headers_custom_type(self): - """ - Custom auth type should inject the raw API key. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"), - ) - - headers = ExternalDatasetService.assembling_headers(auth, headers=None) - - assert headers["X-API-KEY"] == "raw-key" - - def test_assembling_headers_missing_config_raises(self): - """ - Missing config object should be rejected. - """ - - auth = Authorization(type="api-key", config=None) - - with pytest.raises(ValueError, match="authorization config is required"): - ExternalDatasetService.assembling_headers(auth) - - def test_assembling_headers_missing_api_key_raises(self): - """ - ``api_key`` is required when type is ``api-key``. - """ - - auth = Authorization( - type="api-key", - config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"), - ) - - with pytest.raises(ValueError, match="api_key is required"): - ExternalDatasetService.assembling_headers(auth) - - def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self): - """ - For ``no-auth`` we should not modify the headers mapping. - """ - - auth = Authorization(type="no-auth", config=None) - - base_headers = {"X": "1"} - result = ExternalDatasetService.assembling_headers(auth, headers=base_headers) - - # A copy is returned, original is not mutated. - assert result == base_headers - assert result is not base_headers - - -# --------------------------------------------------------------------------- -# get_external_knowledge_api_settings -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceGetExternalKnowledgeApiSettings: - """ - Simple shape test for ``get_external_knowledge_api_settings``. - """ - - def test_get_external_knowledge_api_settings(self): - settings_dict: dict[str, Any] = { - "url": "https://example.com/retrieval", - "request_method": "post", - "headers": {"Content-Type": "application/json"}, - "params": {"foo": "bar"}, - } - - result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict) - - assert isinstance(result, ExternalKnowledgeApiSetting) - assert result.url == settings_dict["url"] - assert result.request_method == settings_dict["request_method"] - assert result.headers == settings_dict["headers"] - assert result.params == settings_dict["params"] - - -# --------------------------------------------------------------------------- -# create_external_dataset -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceCreateExternalDataset: - """ - Tests around creating the external dataset and its binding row. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_create_external_dataset_success(self, mock_db_session: MagicMock): - """ - A brand new dataset name with valid external knowledge references - should create both the dataset and its binding. - """ - - tenant_id = "tenant-1" - user_id = "user-1" - - args = { - "name": "My Dataset", - "description": "desc", - "external_knowledge_api_id": "api-1", - "external_knowledge_id": "knowledge-1", - "external_retrieval_model": {"top_k": 3}, - } - - # No existing dataset with same name. - mock_db_session.scalar.side_effect = [ - None, # duplicate‑name check - Mock(spec=ExternalKnowledgeApis), # external knowledge api - ] - - dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args) - - assert isinstance(dataset, Dataset) - assert dataset.provider == "external" - assert dataset.retrieval_model == args["external_retrieval_model"] - - assert mock_db_session.add.call_count >= 2 # dataset + binding - mock_db_session.flush.assert_called_once() - mock_db_session.commit.assert_called_once() - - def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock): - """ - When a dataset with the same name already exists, - ``DatasetNameDuplicateError`` is raised. - """ - - existing_dataset = Mock(spec=Dataset) - mock_db_session.scalar.return_value = existing_dataset - - args = { - "name": "Existing", - "external_knowledge_api_id": "api-1", - "external_knowledge_id": "knowledge-1", - } - - with pytest.raises(DatasetNameDuplicateError): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args) - - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock): - """ - If the referenced external knowledge API does not exist, a ``ValueError`` is raised. - """ - - # First call: duplicate name check – not found. - mock_db_session.scalar.side_effect = [ - None, - None, # external knowledge api lookup - ] - - args = { - "name": "Dataset", - "external_knowledge_api_id": "missing", - "external_knowledge_id": "knowledge-1", - } - - with pytest.raises(ValueError, match="api template not found"): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args) - - def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock): - """ - ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory. - """ - - # duplicate name check — two calls to create_external_dataset, each does 2 scalar calls - mock_db_session.scalar.side_effect = [ - None, - Mock(spec=ExternalKnowledgeApis), - None, - Mock(spec=ExternalKnowledgeApis), - ] - - args_missing_knowledge_id = { - "name": "Dataset", - "external_knowledge_api_id": "api-1", - "external_knowledge_id": None, - } - - with pytest.raises(ValueError, match="external_knowledge_id is required"): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id) - - args_missing_api_id = { - "name": "Dataset", - "external_knowledge_api_id": None, - "external_knowledge_id": "k-1", - } - - with pytest.raises(ValueError, match="external_knowledge_api_id is required"): - ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id) - - -# --------------------------------------------------------------------------- -# fetch_external_knowledge_retrieval -# --------------------------------------------------------------------------- - - -class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval: - """ - Tests for ``fetch_external_knowledge_retrieval`` which orchestrates - external retrieval requests and normalises the response payload. - """ - - @pytest.fixture - def mock_db_session(self): - with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: - yield mock_session - - def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock): - """ - With a valid binding and API template, records from the external - service should be returned when the HTTP response is 200. - """ - - tenant_id = "tenant-1" - dataset_id = "ds-1" - query = "test query" - external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5} - - binding = ExternalDatasetTestDataFactory.create_external_binding( - tenant_id=tenant_id, - dataset_id=dataset_id, - api_id="api-1", - external_knowledge_id="knowledge-1", - ) - - api = Mock(spec=ExternalKnowledgeApis) - api.settings = '{"endpoint":"https://example.com","api_key":"secret"}' - - # First query: binding; second query: api. - mock_db_session.scalar.side_effect = [ - binding, - api, - ] - - fake_records = [{"content": "doc", "score": 0.9}] - fake_response = Mock(spec=httpx.Response) - fake_response.status_code = 200 - fake_response.json.return_value = {"records": fake_records} - - metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"}) - - with patch.object( - ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True - ) as mock_process: - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id=tenant_id, - dataset_id=dataset_id, - query=query, - external_retrieval_parameters=external_retrieval_parameters, - metadata_condition=metadata_condition, - ) - - assert result == fake_records - - mock_process.assert_called_once() - setting_arg = mock_process.call_args.args[0] - assert isinstance(setting_arg, ExternalKnowledgeApiSetting) - assert setting_arg.url.endswith("/retrieval") - - def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock): - """ - Missing binding should raise ``ValueError``. - """ - - mock_db_session.scalar.return_value = None - - with pytest.raises(ValueError, match="external knowledge binding not found"): - ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id="tenant-1", - dataset_id="missing", - query="q", - external_retrieval_parameters={}, - metadata_condition=None, - ) - - def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock): - """ - When the API template is missing or has no settings, a ``ValueError`` is raised. - """ - - binding = ExternalDatasetTestDataFactory.create_external_binding() - mock_db_session.scalar.side_effect = [ - binding, - None, - ] - - with pytest.raises(ValueError, match="external api template not found"): - ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id="tenant-1", - dataset_id="ds-1", - query="q", - external_retrieval_parameters={}, - metadata_condition=None, - ) - - def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock): - """ - Non‑200 responses should be treated as an empty result set. - """ - - binding = ExternalDatasetTestDataFactory.create_external_binding() - api = Mock(spec=ExternalKnowledgeApis) - api.settings = '{"endpoint":"https://example.com","api_key":"secret"}' - - mock_db_session.scalar.side_effect = [ - binding, - api, - ] - - fake_response = Mock(spec=httpx.Response) - fake_response.status_code = 500 - fake_response.json.return_value = {} - - with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True): - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id="tenant-1", - dataset_id="ds-1", - query="q", - external_retrieval_parameters={}, - metadata_condition=None, - ) - - assert result == [] diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py index 327281d07f..efb79aadde 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py @@ -374,24 +374,14 @@ def test_publish_workflow_success(mocker, rag_pipeline_service) -> None: mock_db = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db", mock_db) mock_dataset_service_class = mocker.patch("services.dataset_service.DatasetService") - mock_dataset_service = mock_dataset_service_class.return_value - # 6. Mock session and its scalar/query methods + # 6. Mock session and dataset lookup mock_session = mocker.Mock() mock_session.scalar.return_value = draft_wf - # Mock dataset update query (needed even if service is mocked, as rag_pipeline fetches it first) dataset = mocker.Mock() dataset.retrieval_model_dict = {} - dataset_query = mocker.Mock() - dataset_query.where.return_value.first.return_value = dataset - - # Mock node execution copy - node_exec_query = mocker.Mock() - node_exec_query.where.return_value.all.return_value = [] - - # Mocked session query side effects - mock_session.query.side_effect = [node_exec_query, dataset_query] + pipeline.retrieve_dataset.return_value = dataset # 7. Run test result = rag_pipeline_service.publish_workflow(session=mock_session, pipeline=pipeline, account=account) @@ -1524,7 +1514,6 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker ) document = SimpleNamespace(indexing_status="waiting", error=None) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document) add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add") commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") @@ -1595,7 +1584,6 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) with pytest.raises(ValueError, match="Dataset not found"): @@ -1604,7 +1592,6 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: dataset = SimpleNamespace(pipeline_id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None]) with pytest.raises(ValueError, match="Pipeline not found"): @@ -1644,7 +1631,6 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None: def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None: template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template) commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) @@ -1871,7 +1857,6 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None]) with pytest.raises(ValueError, match="Workflow not found"): @@ -1910,7 +1895,6 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: exec_log = SimpleNamespace(pipeline_id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None) @@ -1923,7 +1907,6 @@ def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: exec_log = SimpleNamespace(pipeline_id="p1") pipeline = SimpleNamespace(id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None) @@ -1940,7 +1923,6 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r workflow = SimpleNamespace( graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[] ) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) @@ -2103,7 +2085,6 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published( graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]}, rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}], ) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow) mocker.patch( @@ -2143,7 +2124,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag {"variable": "v3", "belong_to_node_id": "shared"}, ], ) - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) mocker.patch( @@ -2161,7 +2141,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None: dataset = SimpleNamespace(pipeline_id="p1") pipeline = SimpleNamespace(id="p1") - query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) result = rag_pipeline_service.get_pipeline("t1", "d1") diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py deleted file mode 100644 index f0a66a00d4..0000000000 --- a/api/tests/unit_tests/services/segment_service.py +++ /dev/null @@ -1,1115 +0,0 @@ -from unittest.mock import MagicMock, Mock, patch - -import pytest - -from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models.account import Account -from models.dataset import ChildChunk, Dataset, Document, DocumentSegment -from models.enums import SegmentType -from services.dataset_service import SegmentService -from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs -from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError - - -class SegmentTestDataFactory: - """Factory class for creating test data and mock objects for segment service tests.""" - - @staticmethod - def create_segment_mock( - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - content: str = "Test segment content", - position: int = 1, - enabled: bool = True, - status: str = "completed", - word_count: int = 3, - tokens: int = 5, - **kwargs, - ) -> Mock: - """Create a mock segment with specified attributes.""" - segment = Mock(spec=DocumentSegment) - segment.id = segment_id - segment.document_id = document_id - segment.dataset_id = dataset_id - segment.tenant_id = tenant_id - segment.content = content - segment.position = position - segment.enabled = enabled - segment.status = status - segment.word_count = word_count - segment.tokens = tokens - segment.index_node_id = f"node-{segment_id}" - segment.index_node_hash = "hash-123" - segment.keywords = [] - segment.answer = None - segment.disabled_at = None - segment.disabled_by = None - segment.updated_by = None - segment.updated_at = None - segment.indexing_at = None - segment.completed_at = None - segment.error = None - for key, value in kwargs.items(): - setattr(segment, key, value) - return segment - - @staticmethod - def create_child_chunk_mock( - chunk_id: str = "chunk-123", - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - content: str = "Test child chunk content", - position: int = 1, - word_count: int = 3, - **kwargs, - ) -> Mock: - """Create a mock child chunk with specified attributes.""" - chunk = Mock(spec=ChildChunk) - chunk.id = chunk_id - chunk.segment_id = segment_id - chunk.document_id = document_id - chunk.dataset_id = dataset_id - chunk.tenant_id = tenant_id - chunk.content = content - chunk.position = position - chunk.word_count = word_count - chunk.index_node_id = f"node-{chunk_id}" - chunk.index_node_hash = "hash-123" - chunk.type = SegmentType.AUTOMATIC - chunk.created_by = "user-123" - chunk.updated_by = None - chunk.updated_at = None - for key, value in kwargs.items(): - setattr(chunk, key, value) - return chunk - - @staticmethod - def create_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - word_count: int = 100, - **kwargs, - ) -> Mock: - """Create a mock document with specified attributes.""" - document = Mock(spec=Document) - document.id = document_id - document.dataset_id = dataset_id - document.tenant_id = tenant_id - document.doc_form = doc_form - document.word_count = word_count - for key, value in kwargs.items(): - setattr(document, key, value) - return document - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, - embedding_model: str = "text-embedding-ada-002", - embedding_model_provider: str = "openai", - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.indexing_technique = indexing_technique - dataset.embedding_model = embedding_model - dataset.embedding_model_provider = embedding_model_provider - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.name = "Test User" - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - -class TestSegmentServiceCreateSegment: - """Tests for SegmentService.create_segment method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_create_segment_success(self, mock_db_session, mock_current_user): - """Test successful creation of a segment.""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = {"content": "New segment content", "keywords": ["test", "segment"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None # No existing segments - mock_db_session.query.return_value = mock_query - - mock_segment = SegmentTestDataFactory.create_segment_mock() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert mock_db_session.add.call_count == 2 - - created_segment = mock_db_session.add.call_args_list[0].args[0] - assert isinstance(created_segment, DocumentSegment) - assert created_segment.content == args["content"] - assert created_segment.word_count == len(args["content"]) - - mock_db_session.commit.assert_called_once() - - mock_vector_service.assert_called_once() - vector_call_args = mock_vector_service.call_args[0] - assert vector_call_args[0] == [args["keywords"]] - assert vector_call_args[1][0] == created_segment - assert vector_call_args[2] == dataset - assert vector_call_args[3] == document.doc_form - - assert result == mock_segment - - def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): - """Test creation of segment with QA model (requires answer).""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - mock_segment = SegmentTestDataFactory.create_segment_mock() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert result == mock_segment - mock_db_session.add.assert_called() - mock_db_session.commit.assert_called() - - def test_create_segment_with_high_quality_indexing(self, mock_db_session, mock_current_user): - """Test creation of segment with high quality indexing technique.""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - args = {"content": "New segment content", "keywords": ["test"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - mock_embedding_model = MagicMock() - mock_embedding_model.get_text_embedding_num_tokens.return_value = [10] - mock_model_manager = MagicMock() - mock_model_manager.get_model_instance.return_value = mock_embedding_model - - mock_segment = SegmentTestDataFactory.create_segment_mock() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.ModelManager.for_tenant", autospec=True) as mock_model_manager_class, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_model_manager_class.return_value = mock_model_manager - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert result == mock_segment - mock_model_manager.get_model_instance.assert_called_once() - mock_embedding_model.get_text_embedding_num_tokens.assert_called_once() - - def test_create_segment_vector_index_failure(self, mock_db_session, mock_current_user): - """Test segment creation when vector indexing fails.""" - # Arrange - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = {"content": "New segment content", "keywords": ["test"]} - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - mock_segment = SegmentTestDataFactory.create_segment_mock(enabled=False, status="error") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_segments_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_vector_service.side_effect = Exception("Vector indexing failed") - mock_hash.return_value = "hash-123" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.create_segment(args, document, dataset) - - # Assert - assert result == mock_segment - assert mock_db_session.commit.call_count == 2 # Once for creation, once for error update - - -class TestSegmentServiceUpdateSegment: - """Tests for SegmentService.update_segment method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_update_segment_content_success(self, mock_db_session, mock_current_user): - """Test successful update of segment content.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) - - mock_db_session.query.return_value.where.return_value.first.return_value = segment - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_redis_get.return_value = None # Not indexing - mock_hash.return_value = "new-hash" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_segment(args, segment, document, dataset) - - # Assert - assert result == segment - assert segment.content == "Updated content" - assert segment.keywords == ["updated"] - assert segment.word_count == len("Updated content") - assert document.word_count == 100 + (len("Updated content") - 10) - mock_db_session.add.assert_called() - mock_db_session.commit.assert_called() - - def test_update_segment_disable(self, mock_db_session, mock_current_user): - """Test disabling a segment.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - args = SegmentUpdateArgs(enabled=False) - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex, - patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_redis_get.return_value = None - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_segment(args, segment, document, dataset) - - # Assert - assert result == segment - assert segment.enabled is False - mock_db_session.add.assert_called() - mock_db_session.commit.assert_called() - mock_task.delay.assert_called_once() - - def test_update_segment_indexing_in_progress(self, mock_db_session, mock_current_user): - """Test update fails when segment is currently indexing.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - args = SegmentUpdateArgs(content="Updated content") - - with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: - mock_redis_get.return_value = "1" # Indexing in progress - - # Act & Assert - with pytest.raises(ValueError, match="Segment is indexing"): - SegmentService.update_segment(args, segment, document, dataset) - - def test_update_segment_disabled_segment(self, mock_db_session, mock_current_user): - """Test update fails when segment is disabled.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=False) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - args = SegmentUpdateArgs(content="Updated content") - - with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: - mock_redis_get.return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="Can't update disabled segment"): - SegmentService.update_segment(args, segment, document, dataset) - - def test_update_segment_with_qa_model(self, mock_db_session, mock_current_user): - """Test update segment with QA model (includes answer).""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) - - mock_db_session.query.return_value.where.return_value.first.return_value = segment - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_redis_get.return_value = None - mock_hash.return_value = "new-hash" - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_segment(args, segment, document, dataset) - - # Assert - assert result == segment - assert segment.content == "Updated question" - assert segment.answer == "Updated answer" - assert segment.keywords == ["qa"] - new_word_count = len("Updated question") + len("Updated answer") - assert segment.word_count == new_word_count - assert document.word_count == 100 + (new_word_count - 10) - mock_db_session.commit.assert_called() - - -class TestSegmentServiceDeleteSegment: - """Tests for SegmentService.delete_segment method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_delete_segment_success(self, mock_db_session): - """Test successful deletion of a segment.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=50) - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock() - - mock_scalars = MagicMock() - mock_scalars.all.return_value = [] - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex, - patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.select", autospec=True) as mock_select, - ): - mock_redis_get.return_value = None - mock_select.return_value.where.return_value = mock_select - - # Act - SegmentService.delete_segment(segment, document, dataset) - - # Assert - mock_db_session.delete.assert_called_once_with(segment) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_delete_segment_disabled(self, mock_db_session): - """Test deletion of disabled segment (no index deletion).""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=False, word_count=50) - document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock() - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, - ): - mock_redis_get.return_value = None - - # Act - SegmentService.delete_segment(segment, document, dataset) - - # Assert - mock_db_session.delete.assert_called_once_with(segment) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_not_called() - - def test_delete_segment_indexing_in_progress(self, mock_db_session): - """Test deletion fails when segment is currently being deleted.""" - # Arrange - segment = SegmentTestDataFactory.create_segment_mock(enabled=True) - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: - mock_redis_get.return_value = "1" # Deletion in progress - - # Act & Assert - with pytest.raises(ValueError, match="Segment is deleting"): - SegmentService.delete_segment(segment, document, dataset) - - -class TestSegmentServiceDeleteSegments: - """Tests for SegmentService.delete_segments method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_delete_segments_success(self, mock_db_session, mock_current_user): - """Test successful deletion of multiple segments.""" - # Arrange - segment_ids = ["segment-1", "segment-2"] - document = SegmentTestDataFactory.create_document_mock(word_count=200) - dataset = SegmentTestDataFactory.create_dataset_mock() - - segments_info = [ - ("node-1", "segment-1", 50), - ("node-2", "segment-2", 30), - ] - - mock_query = MagicMock() - mock_query.with_entities.return_value.where.return_value.all.return_value = segments_info - mock_db_session.query.return_value = mock_query - - mock_scalars = MagicMock() - mock_scalars.all.return_value = [] - mock_select = MagicMock() - mock_select.where.return_value = mock_select - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.select", autospec=True) as mock_select_func, - ): - mock_select_func.return_value = mock_select - - # Act - SegmentService.delete_segments(segment_ids, document, dataset) - - # Assert - mock_db_session.query.return_value.where.return_value.delete.assert_called_once() - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_delete_segments_empty_list(self, mock_db_session, mock_current_user): - """Test deletion with empty list (should return early).""" - # Arrange - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - # Act - SegmentService.delete_segments([], document, dataset) - - # Assert - mock_db_session.query.assert_not_called() - - -class TestSegmentServiceUpdateSegmentsStatus: - """Tests for SegmentService.update_segments_status method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_update_segments_status_enable(self, mock_db_session, mock_current_user): - """Test enabling multiple segments.""" - # Arrange - segment_ids = ["segment-1", "segment-2"] - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - segments = [ - SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=False), - SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=False), - ] - - mock_scalars = MagicMock() - mock_scalars.all.return_value = segments - mock_select = MagicMock() - mock_select.where.return_value = mock_select - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task, - patch("services.dataset_service.select", autospec=True) as mock_select_func, - ): - mock_redis_get.return_value = None - mock_select_func.return_value = mock_select - - # Act - SegmentService.update_segments_status(segment_ids, "enable", dataset, document) - - # Assert - assert all(seg.enabled is True for seg in segments) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_update_segments_status_disable(self, mock_db_session, mock_current_user): - """Test disabling multiple segments.""" - # Arrange - segment_ids = ["segment-1", "segment-2"] - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - segments = [ - SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=True), - SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=True), - ] - - mock_scalars = MagicMock() - mock_scalars.all.return_value = segments - mock_select = MagicMock() - mock_select.where.return_value = mock_select - mock_db_session.scalars.return_value = mock_scalars - - with ( - patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, - patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - patch("services.dataset_service.select", autospec=True) as mock_select_func, - ): - mock_redis_get.return_value = None - mock_now.return_value = "2024-01-01T00:00:00" - mock_select_func.return_value = mock_select - - # Act - SegmentService.update_segments_status(segment_ids, "disable", dataset, document) - - # Assert - assert all(seg.enabled is False for seg in segments) - mock_db_session.commit.assert_called_once() - mock_task.delay.assert_called_once() - - def test_update_segments_status_empty_list(self, mock_db_session, mock_current_user): - """Test update with empty list (should return early).""" - # Arrange - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - # Act - SegmentService.update_segments_status([], "enable", dataset, document) - - # Assert - mock_db_session.scalars.assert_not_called() - - -class TestSegmentServiceGetSegments: - """Tests for SegmentService.get_segments method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_get_segments_success(self, mock_db_session, mock_current_user): - """Test successful retrieval of segments.""" - # Arrange - document_id = "doc-123" - tenant_id = "tenant-123" - segments = [ - SegmentTestDataFactory.create_segment_mock(segment_id="segment-1"), - SegmentTestDataFactory.create_segment_mock(segment_id="segment-2"), - ] - - mock_paginate = MagicMock() - mock_paginate.items = segments - mock_paginate.total = 2 - mock_db_session.paginate.return_value = mock_paginate - - # Act - items, total = SegmentService.get_segments(document_id, tenant_id) - - # Assert - assert len(items) == 2 - assert total == 2 - mock_db_session.paginate.assert_called_once() - - def test_get_segments_with_status_filter(self, mock_db_session, mock_current_user): - """Test retrieval with status filter.""" - # Arrange - document_id = "doc-123" - tenant_id = "tenant-123" - status_list = ["completed", "error"] - - mock_paginate = MagicMock() - mock_paginate.items = [] - mock_paginate.total = 0 - mock_db_session.paginate.return_value = mock_paginate - - # Act - items, total = SegmentService.get_segments(document_id, tenant_id, status_list=status_list) - - # Assert - assert len(items) == 0 - assert total == 0 - - def test_get_segments_with_keyword(self, mock_db_session, mock_current_user): - """Test retrieval with keyword search.""" - # Arrange - document_id = "doc-123" - tenant_id = "tenant-123" - keyword = "test" - - mock_paginate = MagicMock() - mock_paginate.items = [SegmentTestDataFactory.create_segment_mock()] - mock_paginate.total = 1 - mock_db_session.paginate.return_value = mock_paginate - - # Act - items, total = SegmentService.get_segments(document_id, tenant_id, keyword=keyword) - - # Assert - assert len(items) == 1 - assert total == 1 - - -class TestSegmentServiceGetSegmentById: - """Tests for SegmentService.get_segment_by_id method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_get_segment_by_id_success(self, mock_db_session): - """Test successful retrieval of segment by ID.""" - # Arrange - segment_id = "segment-123" - tenant_id = "tenant-123" - segment = SegmentTestDataFactory.create_segment_mock(segment_id=segment_id) - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = segment - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_segment_by_id(segment_id, tenant_id) - - # Assert - assert result == segment - - def test_get_segment_by_id_not_found(self, mock_db_session): - """Test retrieval when segment is not found.""" - # Arrange - segment_id = "non-existent" - tenant_id = "tenant-123" - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_segment_by_id(segment_id, tenant_id) - - # Assert - assert result is None - - -class TestSegmentServiceGetChildChunks: - """Tests for SegmentService.get_child_chunks method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_get_child_chunks_success(self, mock_db_session, mock_current_user): - """Test successful retrieval of child chunks.""" - # Arrange - segment_id = "segment-123" - document_id = "doc-123" - dataset_id = "dataset-123" - page = 1 - limit = 20 - - mock_paginate = MagicMock() - mock_paginate.items = [ - SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-1"), - SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-2"), - ] - mock_paginate.total = 2 - mock_db_session.paginate.return_value = mock_paginate - - # Act - result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit) - - # Assert - assert result == mock_paginate - mock_db_session.paginate.assert_called_once() - - def test_get_child_chunks_with_keyword(self, mock_db_session, mock_current_user): - """Test retrieval with keyword search.""" - # Arrange - segment_id = "segment-123" - document_id = "doc-123" - dataset_id = "dataset-123" - page = 1 - limit = 20 - keyword = "test" - - mock_paginate = MagicMock() - mock_paginate.items = [] - mock_paginate.total = 0 - mock_db_session.paginate.return_value = mock_paginate - - # Act - result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword=keyword) - - # Assert - assert result == mock_paginate - - -class TestSegmentServiceGetChildChunkById: - """Tests for SegmentService.get_child_chunk_by_id method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_get_child_chunk_by_id_success(self, mock_db_session): - """Test successful retrieval of child chunk by ID.""" - # Arrange - chunk_id = "chunk-123" - tenant_id = "tenant-123" - chunk = SegmentTestDataFactory.create_child_chunk_mock(chunk_id=chunk_id) - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = chunk - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id) - - # Assert - assert result == chunk - - def test_get_child_chunk_by_id_not_found(self, mock_db_session): - """Test retrieval when child chunk is not found.""" - # Arrange - chunk_id = "non-existent" - tenant_id = "tenant-123" - - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id) - - # Assert - assert result is None - - -class TestSegmentServiceCreateChildChunk: - """Tests for SegmentService.create_child_chunk method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_create_child_chunk_success(self, mock_db_session, mock_current_user): - """Test successful creation of a child chunk.""" - # Arrange - content = "New child chunk content" - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_hash.return_value = "hash-123" - - # Act - result = SegmentService.create_child_chunk(content, segment, document, dataset) - - # Assert - assert result is not None - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - mock_vector_service.assert_called_once() - - def test_create_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user): - """Test child chunk creation when vector indexing fails.""" - # Arrange - content = "New child chunk content" - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - mock_query = MagicMock() - mock_query.where.return_value.scalar.return_value = None - mock_db_session.query.return_value = mock_query - - with ( - patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, - patch( - "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, - ): - mock_lock.return_value.__enter__ = Mock() - mock_lock.return_value.__exit__ = Mock(return_value=None) - mock_vector_service.side_effect = Exception("Vector indexing failed") - mock_hash.return_value = "hash-123" - - # Act & Assert - with pytest.raises(ChildChunkIndexingError): - SegmentService.create_child_chunk(content, segment, document, dataset) - - mock_db_session.rollback.assert_called_once() - - -class TestSegmentServiceUpdateChildChunk: - """Tests for SegmentService.update_child_chunk method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current_user.""" - user = SegmentTestDataFactory.create_user_mock() - with patch("services.dataset_service.current_user", user): - yield user - - def test_update_child_chunk_success(self, mock_db_session, mock_current_user): - """Test successful update of a child chunk.""" - # Arrange - content = "Updated child chunk content" - chunk = SegmentTestDataFactory.create_child_chunk_mock() - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with ( - patch( - "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_now.return_value = "2024-01-01T00:00:00" - - # Act - result = SegmentService.update_child_chunk(content, chunk, segment, document, dataset) - - # Assert - assert result == chunk - assert chunk.content == content - assert chunk.word_count == len(content) - mock_db_session.add.assert_called_once_with(chunk) - mock_db_session.commit.assert_called_once() - mock_vector_service.assert_called_once() - - def test_update_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user): - """Test child chunk update when vector indexing fails.""" - # Arrange - content = "Updated content" - chunk = SegmentTestDataFactory.create_child_chunk_mock() - segment = SegmentTestDataFactory.create_segment_mock() - document = SegmentTestDataFactory.create_document_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with ( - patch( - "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True - ) as mock_vector_service, - patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, - ): - mock_vector_service.side_effect = Exception("Vector indexing failed") - mock_now.return_value = "2024-01-01T00:00:00" - - # Act & Assert - with pytest.raises(ChildChunkIndexingError): - SegmentService.update_child_chunk(content, chunk, segment, document, dataset) - - mock_db_session.rollback.assert_called_once() - - -class TestSegmentServiceDeleteChildChunk: - """Tests for SegmentService.delete_child_chunk method.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session.""" - with patch("services.dataset_service.db.session", autospec=True) as mock_db: - yield mock_db - - def test_delete_child_chunk_success(self, mock_db_session): - """Test successful deletion of a child chunk.""" - # Arrange - chunk = SegmentTestDataFactory.create_child_chunk_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with patch( - "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True - ) as mock_vector_service: - # Act - SegmentService.delete_child_chunk(chunk, dataset) - - # Assert - mock_db_session.delete.assert_called_once_with(chunk) - mock_db_session.commit.assert_called_once() - mock_vector_service.assert_called_once_with(chunk, dataset) - - def test_delete_child_chunk_vector_index_failure(self, mock_db_session): - """Test child chunk deletion when vector indexing fails.""" - # Arrange - chunk = SegmentTestDataFactory.create_child_chunk_mock() - dataset = SegmentTestDataFactory.create_dataset_mock() - - with patch( - "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True - ) as mock_vector_service: - mock_vector_service.side_effect = Exception("Vector deletion failed") - - # Act & Assert - with pytest.raises(ChildChunkDeleteIndexError): - SegmentService.delete_child_chunk(chunk, dataset) - - mock_db_session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/services_test_help.py b/api/tests/unit_tests/services/services_test_help.py deleted file mode 100644 index c6b962f7fc..0000000000 --- a/api/tests/unit_tests/services/services_test_help.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - - -class ServiceDbTestHelper: - """ - Helper class for service database query tests. - """ - - @staticmethod - def setup_db_query_filter_by_mock(mock_db, query_results): - """ - Smart database query mock that responds based on model type and query parameters. - - Args: - mock_db: Mock database session - query_results: Dict mapping (model_name, filter_key, filter_value) to return value - Example: {('Account', 'email', 'test@example.com'): mock_account} - """ - - def query_side_effect(model): - mock_query = MagicMock() - - def filter_by_side_effect(**kwargs): - mock_filter_result = MagicMock() - - def first_side_effect(): - # Find matching result based on model and filter parameters - for (model_name, filter_key, filter_value), result in query_results.items(): - if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value: - return result - return None - - mock_filter_result.first.side_effect = first_side_effect - - # Handle order_by calls for complex queries - def order_by_side_effect(*args, **kwargs): - mock_order_result = MagicMock() - - def order_first_side_effect(): - # Look for order_by results in the same query_results dict - for (model_name, filter_key, filter_value), result in query_results.items(): - if ( - model.__name__ == model_name - and filter_key == "order_by" - and filter_value == "first_available" - ): - return result - return None - - mock_order_result.first.side_effect = order_first_side_effect - return mock_order_result - - mock_filter_result.order_by.side_effect = order_by_side_effect - return mock_filter_result - - mock_query.filter_by.side_effect = filter_by_side_effect - return mock_query - - mock_db.session.query.side_effect = query_side_effect diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index c4f5f57153..e9d2f1481e 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -14,7 +14,6 @@ from services.errors.account import ( AccountRegisterError, CurrentPasswordIncorrectError, ) -from tests.unit_tests.services.services_test_help import ServiceDbTestHelper class TestAccountAssociatedDataFactory: @@ -149,7 +148,6 @@ class TestAccountService: # Setup basic session methods mock_session.add = MagicMock() mock_session.commit = MagicMock() - mock_session.query = MagicMock() yield mock_db @@ -1572,15 +1570,9 @@ class TestRegisterService: account_id="existing-user-456", email="existing@example.com", status="active" ) - # Mock database queries - query_results = { - ( - "TenantAccountJoin", - "tenant_id", - "tenant-456", - ): TestAccountAssociatedDataFactory.create_tenant_join_mock(), - } - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies[ + "db" + ].session.scalar.return_value = TestAccountAssociatedDataFactory.create_tenant_join_mock() # Mock TenantService methods with ( diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py index d6c104708c..5cfef76719 100644 --- a/api/tests/unit_tests/services/test_dataset_service_segment.py +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -714,7 +714,6 @@ class TestSegmentServiceMutations: patch("services.dataset_service.db") as mock_db, patch("services.dataset_service.delete_segment_from_index_task") as delete_task, ): - segments_query = MagicMock() # execute().all() for segments_info (multi-column) execute_result = MagicMock() execute_result.all.return_value = [ diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index d304e0ec44..c389c4a635 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -36,9 +36,7 @@ class TestDatasourceProviderService: @pytest.fixture def mock_db_session(self): """ - Robust, chainable query mock. - q returns itself for .filter_by(), .order_by(), .where() so any - SQLAlchemy chaining pattern works without multiple brittle sub-mocks. + Mock session with scalar/scalars defaults for current SQLAlchemy access paths. """ with ( patch("services.datasource_provider_service.Session") as mock_cls, @@ -46,20 +44,6 @@ class TestDatasourceProviderService: ): sess = MagicMock(spec=Session) - q = MagicMock() - sess.query.return_value = q - - # Self-returning chain — any method called on q returns q - q.filter_by.return_value = q - q.order_by.return_value = q - q.where.return_value = q - - # Default terminal values (tests override per-case) - q.first.return_value = None - q.all.return_value = [] - q.count.return_value = 0 - q.delete.return_value = 1 - # Default values for select()-style calls (tests override per-case) sess.scalar.return_value = None sess.scalars.return_value.all.return_value = [] diff --git a/api/tests/unit_tests/services/test_webhook_service_additional.py b/api/tests/unit_tests/services/test_webhook_service_additional.py index 776cb5dc3f..491dd94842 100644 --- a/api/tests/unit_tests/services/test_webhook_service_additional.py +++ b/api/tests/unit_tests/services/test_webhook_service_additional.py @@ -17,23 +17,6 @@ from services.trigger import webhook_service as service_module from services.trigger.webhook_service import WebhookService -class _FakeQuery: - def __init__(self, result: Any) -> None: - self._result = result - - def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery": - return self - - def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery": - return self - - def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery": - return self - - def first(self) -> Any: - return self._result - - @pytest.fixture def flask_app() -> Flask: return Flask(__name__) diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 0015e8b908..feafada59a 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -1649,8 +1649,6 @@ class TestWorkflowServiceCredentialValidation: """Missing BuiltinToolProvider → plugin requires no credentials → no error.""" # Arrange with patch("services.workflow_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None - # Act + Assert (should NOT raise) service._check_default_tool_credential("tenant-1", "some-provider") @@ -1662,10 +1660,6 @@ class TestWorkflowServiceCredentialValidation: patch("services.workflow_service.db") as mock_db, patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")), ): - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( - mock_provider - ) - # Act + Assert with pytest.raises(ValueError, match="Failed to validate default credential"): service._check_default_tool_credential("tenant-1", "some-provider") diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py deleted file mode 100644 index ad80beb4e3..0000000000 --- a/api/tests/unit_tests/services/vector_service.py +++ /dev/null @@ -1,1793 +0,0 @@ -""" -Comprehensive unit tests for VectorService and Vector classes. - -This module contains extensive unit tests for the VectorService and Vector -classes, which are critical components in the RAG (Retrieval-Augmented Generation) -pipeline that handle vector database operations, collection management, embedding -storage and retrieval, and metadata filtering. - -The VectorService provides methods for: -- Creating vector embeddings for document segments -- Updating segment vector embeddings -- Generating child chunks for hierarchical indexing -- Managing child chunk vectors (create, update, delete) - -The Vector class provides methods for: -- Vector database operations (create, add, delete, search) -- Collection creation and management with Redis locking -- Embedding storage and retrieval -- Vector index operations (HNSW, L2 distance, etc.) -- Metadata filtering in vector space -- Support for multiple vector database backends - -This test suite ensures: -- Correct vector database operations -- Proper collection creation and management -- Accurate embedding storage and retrieval -- Comprehensive vector search functionality -- Metadata filtering and querying -- Error conditions are handled correctly -- Edge cases are properly validated - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The Vector service system is a critical component that bridges document -segments and vector databases, enabling semantic search and retrieval. - -1. VectorService: - - High-level service for managing vector operations on document segments - - Handles both regular segments and hierarchical (parent-child) indexing - - Integrates with IndexProcessor for document transformation - - Manages embedding model instances via ModelManager - -2. Vector Class: - - Wrapper around BaseVector implementations - - Handles embedding generation via ModelManager - - Supports multiple vector database backends (Chroma, Milvus, Qdrant, etc.) - - Manages collection creation with Redis locking for concurrency control - - Provides batch processing for large document sets - -3. BaseVector Abstract Class: - - Defines interface for vector database operations - - Implemented by various vector database backends - - Provides methods for CRUD operations on vectors - - Supports both vector similarity search and full-text search - -4. Collection Management: - - Uses Redis locks to prevent concurrent collection creation - - Caches collection existence status in Redis - - Supports collection deletion with cache invalidation - -5. Embedding Generation: - - Uses ModelManager to get embedding model instances - - Supports cached embeddings for performance - - Handles batch processing for large document sets - - Generates embeddings for both documents and queries - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. VectorService Methods: - - create_segments_vector: Regular and hierarchical indexing - - update_segment_vector: Vector and keyword index updates - - generate_child_chunks: Child chunk generation with full doc mode - - create_child_chunk_vector: Child chunk vector creation - - update_child_chunk_vector: Batch child chunk updates - - delete_child_chunk_vector: Child chunk deletion - -2. Vector Class Methods: - - Initialization with dataset and attributes - - Collection creation with Redis locking - - Embedding generation and batch processing - - Vector operations (create, add_texts, delete_by_ids, etc.) - - Search operations (by vector, by full text) - - Metadata filtering and querying - - Duplicate checking logic - - Vector factory selection - -3. Integration Points: - - ModelManager integration for embedding models - - IndexProcessor integration for document transformation - - Redis integration for locking and caching - - Database session management - - Vector database backend abstraction - -4. Error Handling: - - Invalid vector store configuration - - Missing embedding models - - Collection creation failures - - Search operation errors - - Metadata filtering errors - -5. Edge Cases: - - Empty document lists - - Missing metadata fields - - Duplicate document IDs - - Large batch processing - - Concurrent collection creation - -================================================================================ -""" - -from typing import Any -from unittest.mock import Mock, patch - -import pytest - -from core.rag.datasource.vdb.vector_base import BaseVector -from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.datasource.vdb.vector_type import VectorType -from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from core.rag.models.document import Document -from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment -from services.vector_service import VectorService - -# ============================================================================ -# Test Data Factory -# ============================================================================ - - -class VectorServiceTestDataFactory: - """ - Factory class for creating test data and mock objects for Vector service tests. - - This factory provides static methods to create mock objects for: - - Dataset instances with various configurations - - DocumentSegment instances - - ChildChunk instances - - Document instances (RAG documents) - - Embedding model instances - - Vector processor mocks - - Index processor mocks - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, - embedding_model_provider: str = "openai", - embedding_model: str = "text-embedding-ada-002", - index_struct_dict: dict[str, Any] | None = None, - **kwargs, - ) -> Mock: - """ - Create a mock Dataset with specified attributes. - - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier - doc_form: Document form type - indexing_technique: Indexing technique (high_quality or economy) - embedding_model_provider: Embedding model provider - embedding_model: Embedding model name - index_struct_dict: Index structure dictionary - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Dataset instance - """ - dataset = Mock(spec=Dataset) - - dataset.id = dataset_id - - dataset.tenant_id = tenant_id - - dataset.doc_form = doc_form - - dataset.indexing_technique = indexing_technique - - dataset.embedding_model_provider = embedding_model_provider - - dataset.embedding_model = embedding_model - - dataset.index_struct_dict = index_struct_dict - - for key, value in kwargs.items(): - setattr(dataset, key, value) - - return dataset - - @staticmethod - def create_document_segment_mock( - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - content: str = "Test segment content", - index_node_id: str = "node-123", - index_node_hash: str = "hash-123", - **kwargs, - ) -> Mock: - """ - Create a mock DocumentSegment with specified attributes. - - Args: - segment_id: Unique identifier for the segment - document_id: Parent document identifier - dataset_id: Dataset identifier - content: Segment content text - index_node_id: Index node identifier - index_node_hash: Index node hash - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DocumentSegment instance - """ - segment = Mock(spec=DocumentSegment) - - segment.id = segment_id - - segment.document_id = document_id - - segment.dataset_id = dataset_id - - segment.content = content - - segment.index_node_id = index_node_id - - segment.index_node_hash = index_node_hash - - for key, value in kwargs.items(): - setattr(segment, key, value) - - return segment - - @staticmethod - def create_child_chunk_mock( - chunk_id: str = "chunk-123", - segment_id: str = "segment-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - content: str = "Test child chunk content", - index_node_id: str = "node-chunk-123", - index_node_hash: str = "hash-chunk-123", - position: int = 1, - **kwargs, - ) -> Mock: - """ - Create a mock ChildChunk with specified attributes. - - Args: - chunk_id: Unique identifier for the child chunk - segment_id: Parent segment identifier - document_id: Parent document identifier - dataset_id: Dataset identifier - tenant_id: Tenant identifier - content: Child chunk content text - index_node_id: Index node identifier - index_node_hash: Index node hash - position: Position in parent segment - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a ChildChunk instance - """ - chunk = Mock(spec=ChildChunk) - - chunk.id = chunk_id - - chunk.segment_id = segment_id - - chunk.document_id = document_id - - chunk.dataset_id = dataset_id - - chunk.tenant_id = tenant_id - - chunk.content = content - - chunk.index_node_id = index_node_id - - chunk.index_node_hash = index_node_hash - - chunk.position = position - - for key, value in kwargs.items(): - setattr(chunk, key, value) - - return chunk - - @staticmethod - def create_dataset_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - dataset_process_rule_id: str = "rule-123", - doc_language: str = "en", - created_by: str = "user-123", - **kwargs, - ) -> Mock: - """ - Create a mock DatasetDocument with specified attributes. - - Args: - document_id: Unique identifier for the document - dataset_id: Dataset identifier - tenant_id: Tenant identifier - dataset_process_rule_id: Process rule identifier - doc_language: Document language - created_by: Creator user ID - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DatasetDocument instance - """ - document = Mock(spec=DatasetDocument) - - document.id = document_id - - document.dataset_id = dataset_id - - document.tenant_id = tenant_id - - document.dataset_process_rule_id = dataset_process_rule_id - - document.doc_language = doc_language - - document.created_by = created_by - - for key, value in kwargs.items(): - setattr(document, key, value) - - return document - - @staticmethod - def create_dataset_process_rule_mock( - rule_id: str = "rule-123", - **kwargs, - ) -> Mock: - """ - Create a mock DatasetProcessRule with specified attributes. - - Args: - rule_id: Unique identifier for the process rule - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DatasetProcessRule instance - """ - rule = Mock(spec=DatasetProcessRule) - - rule.id = rule_id - - rule.to_dict = Mock(return_value={"rules": {"parent_mode": "chunk"}}) - - for key, value in kwargs.items(): - setattr(rule, key, value) - - return rule - - @staticmethod - def create_rag_document_mock( - page_content: str = "Test document content", - doc_id: str = "doc-123", - doc_hash: str = "hash-123", - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - **kwargs, - ) -> Document: - """ - Create a RAG Document with specified attributes. - - Args: - page_content: Document content text - doc_id: Document identifier in metadata - doc_hash: Document hash in metadata - document_id: Parent document ID in metadata - dataset_id: Dataset ID in metadata - **kwargs: Additional metadata fields - - Returns: - Document instance configured for testing - """ - metadata = { - "doc_id": doc_id, - "doc_hash": doc_hash, - "document_id": document_id, - "dataset_id": dataset_id, - } - - metadata.update(kwargs) - - return Document(page_content=page_content, metadata=metadata) - - @staticmethod - def create_embedding_model_instance_mock() -> Mock: - """ - Create a mock embedding model instance. - - Returns: - Mock object configured as an embedding model instance - """ - model_instance = Mock() - - model_instance.embed_documents = Mock(return_value=[[0.1] * 1536]) - - model_instance.embed_query = Mock(return_value=[0.1] * 1536) - - return model_instance - - @staticmethod - def create_vector_processor_mock() -> Mock: - """ - Create a mock vector processor (BaseVector implementation). - - Returns: - Mock object configured as a BaseVector instance - """ - processor = Mock(spec=BaseVector) - - processor.collection_name = "test_collection" - - processor.create = Mock() - - processor.add_texts = Mock() - - processor.text_exists = Mock(return_value=False) - - processor.delete_by_ids = Mock() - - processor.delete_by_metadata_field = Mock() - - processor.search_by_vector = Mock(return_value=[]) - - processor.search_by_full_text = Mock(return_value=[]) - - processor.delete = Mock() - - return processor - - @staticmethod - def create_index_processor_mock() -> Mock: - """ - Create a mock index processor. - - Returns: - Mock object configured as an index processor instance - """ - processor = Mock() - - processor.load = Mock() - - processor.clean = Mock() - - processor.transform = Mock(return_value=[]) - - return processor - - -# ============================================================================ -# Tests for VectorService -# ============================================================================ - - -class TestVectorService: - """ - Comprehensive unit tests for VectorService class. - - This test class covers all methods of the VectorService class, including - segment vector operations, child chunk operations, and integration with - various components like IndexProcessor and ModelManager. - """ - - # ======================================================================== - # Tests for create_segments_vector - # ======================================================================== - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_create_segments_vector_regular_indexing(self, mock_db, mock_index_processor_factory): - """ - Test create_segments_vector with regular indexing (non-hierarchical). - - This test verifies that segments are correctly converted to RAG documents - and loaded into the index processor for regular indexing scenarios. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - keywords_list = [["keyword1", "keyword2"]] - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) - - # Assert - mock_index_processor.load.assert_called_once() - - call_args = mock_index_processor.load.call_args - - assert call_args[0][0] == dataset - - assert len(call_args[0][1]) == 1 - - assert call_args[1]["with_keywords"] is True - - assert call_args[1]["keywords_list"] == keywords_list - - @patch("services.vector_service.VectorService.generate_child_chunks") - @patch("services.vector_service.ModelManager.for_tenant") - @patch("services.vector_service.db") - def test_create_segments_vector_parent_child_indexing( - self, mock_db, mock_model_manager, mock_generate_child_chunks - ): - """ - Test create_segments_vector with parent-child indexing. - - This test verifies that for hierarchical indexing, child chunks are - generated instead of regular segment indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset_document - - mock_db.session.query.return_value.where.return_value.first.return_value = processing_rule - - mock_embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_model - - # Act - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - # Assert - mock_generate_child_chunks.assert_called_once() - - @patch("services.vector_service.db") - def test_create_segments_vector_missing_document(self, mock_db): - """ - Test create_segments_vector when document is missing. - - This test verifies that when a document is not found, the segment - is skipped with a warning log. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - # Assert - # Should not raise an error, just skip the segment - - @patch("services.vector_service.db") - def test_create_segments_vector_missing_processing_rule(self, mock_db): - """ - Test create_segments_vector when processing rule is missing. - - This test verifies that when a processing rule is not found, a - ValueError is raised. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset_document - - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="No processing rule found"): - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - @patch("services.vector_service.db") - def test_create_segments_vector_economy_indexing_technique(self, mock_db): - """ - Test create_segments_vector with economy indexing technique. - - This test verifies that when indexing_technique is not high_quality, - a ValueError is raised for parent-child indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY - ) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset_document - - mock_db.session.query.return_value.where.return_value.first.return_value = processing_rule - - # Act & Assert - with pytest.raises(ValueError, match="The knowledge base index technique is not high quality"): - VectorService.create_segments_vector(None, [segment], dataset, "parent_child_model") - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_create_segments_vector_empty_documents(self, mock_db, mock_index_processor_factory): - """ - Test create_segments_vector with empty documents list. - - This test verifies that when no documents are created, the index - processor is not called. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) - - # Assert - mock_index_processor.load.assert_not_called() - - # ======================================================================== - # Tests for update_segment_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_segment_vector_high_quality(self, mock_db, mock_vector_class): - """ - Test update_segment_vector with high_quality indexing technique. - - This test verifies that segments are correctly updated in the vector - store when using high_quality indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_segment_vector(None, segment, dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once_with([segment.index_node_id]) - - mock_vector.add_texts.assert_called_once() - - @patch("services.vector_service.Keyword") - @patch("services.vector_service.db") - def test_update_segment_vector_economy_with_keywords(self, mock_db, mock_keyword_class): - """ - Test update_segment_vector with economy indexing and keywords. - - This test verifies that segments are correctly updated in the keyword - index when using economy indexing with keywords. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - keywords = ["keyword1", "keyword2"] - - mock_keyword = Mock() - - mock_keyword.delete_by_ids = Mock() - - mock_keyword.add_texts = Mock() - - mock_keyword_class.return_value = mock_keyword - - # Act - VectorService.update_segment_vector(keywords, segment, dataset) - - # Assert - mock_keyword.delete_by_ids.assert_called_once_with([segment.index_node_id]) - - mock_keyword.add_texts.assert_called_once() - - call_args = mock_keyword.add_texts.call_args - - assert call_args[1]["keywords_list"] == [keywords] - - @patch("services.vector_service.Keyword") - @patch("services.vector_service.db") - def test_update_segment_vector_economy_without_keywords(self, mock_db, mock_keyword_class): - """ - Test update_segment_vector with economy indexing without keywords. - - This test verifies that segments are correctly updated in the keyword - index when using economy indexing without keywords. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - mock_keyword = Mock() - - mock_keyword.delete_by_ids = Mock() - - mock_keyword.add_texts = Mock() - - mock_keyword_class.return_value = mock_keyword - - # Act - VectorService.update_segment_vector(None, segment, dataset) - - # Assert - mock_keyword.delete_by_ids.assert_called_once_with([segment.index_node_id]) - - mock_keyword.add_texts.assert_called_once() - - call_args = mock_keyword.add_texts.call_args - - assert "keywords_list" not in call_args[1] or call_args[1].get("keywords_list") is None - - # ======================================================================== - # Tests for generate_child_chunks - # ======================================================================== - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_generate_child_chunks_with_children(self, mock_db, mock_index_processor_factory): - """ - Test generate_child_chunks when children are generated. - - This test verifies that child chunks are correctly generated and - saved to the database when the index processor returns children. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - child_document = VectorServiceTestDataFactory.create_rag_document_mock( - page_content="Child content", doc_id="child-node-123" - ) - - child_document.children = [child_document] - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor.transform.return_value = [child_document] - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.generate_child_chunks(segment, dataset_document, dataset, embedding_model, processing_rule, False) - - # Assert - mock_index_processor.transform.assert_called_once() - - mock_index_processor.load.assert_called_once() - - mock_db.session.add.assert_called() - - mock_db.session.commit.assert_called_once() - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_generate_child_chunks_regenerate(self, mock_db, mock_index_processor_factory): - """ - Test generate_child_chunks with regenerate=True. - - This test verifies that when regenerate is True, existing child chunks - are cleaned before generating new ones. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor.transform.return_value = [] - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.generate_child_chunks(segment, dataset_document, dataset, embedding_model, processing_rule, True) - - # Assert - mock_index_processor.clean.assert_called_once() - - call_args = mock_index_processor.clean.call_args - - assert call_args[0][0] == dataset - - assert call_args[0][1] == [segment.index_node_id] - - assert call_args[1]["with_keywords"] is True - - assert call_args[1]["delete_child_chunks"] is True - - @patch("services.vector_service.IndexProcessorFactory") - @patch("services.vector_service.db") - def test_generate_child_chunks_no_children(self, mock_db, mock_index_processor_factory): - """ - Test generate_child_chunks when no children are generated. - - This test verifies that when the index processor returns no children, - no child chunks are saved to the database. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - segment = VectorServiceTestDataFactory.create_document_segment_mock() - - dataset_document = VectorServiceTestDataFactory.create_dataset_document_mock() - - processing_rule = VectorServiceTestDataFactory.create_dataset_process_rule_mock() - - embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_index_processor = VectorServiceTestDataFactory.create_index_processor_mock() - - mock_index_processor.transform.return_value = [] - - mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor - - # Act - VectorService.generate_child_chunks(segment, dataset_document, dataset, embedding_model, processing_rule, False) - - # Assert - mock_index_processor.transform.assert_called_once() - - mock_index_processor.load.assert_not_called() - - mock_db.session.add.assert_not_called() - - # ======================================================================== - # Tests for create_child_chunk_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_create_child_chunk_vector_high_quality(self, mock_db, mock_vector_class): - """ - Test create_child_chunk_vector with high_quality indexing. - - This test verifies that child chunk vectors are correctly created - when using high_quality indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.create_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.add_texts.assert_called_once() - - call_args = mock_vector.add_texts.call_args - - assert call_args[1]["duplicate_check"] is True - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_create_child_chunk_vector_economy(self, mock_db, mock_vector_class): - """ - Test create_child_chunk_vector with economy indexing. - - This test verifies that child chunk vectors are not created when - using economy indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.create_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.add_texts.assert_not_called() - - # ======================================================================== - # Tests for update_child_chunk_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_with_all_operations(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with new, update, and delete operations. - - This test verifies that child chunk vectors are correctly updated - when there are new chunks, updated chunks, and deleted chunks. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1") - - update_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="update-chunk-1") - - delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="delete-chunk-1") - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([new_chunk], [update_chunk], [delete_chunk], dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once() - - delete_ids = mock_vector.delete_by_ids.call_args[0][0] - - assert update_chunk.index_node_id in delete_ids - - assert delete_chunk.index_node_id in delete_ids - - mock_vector.add_texts.assert_called_once() - - call_args = mock_vector.add_texts.call_args - - assert len(call_args[0][0]) == 2 # new_chunk + update_chunk - - assert call_args[1]["duplicate_check"] is True - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_only_new(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with only new chunks. - - This test verifies that when only new chunks are provided, only - add_texts is called, not delete_by_ids. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([new_chunk], [], [], dataset) - - # Assert - mock_vector.delete_by_ids.assert_not_called() - - mock_vector.add_texts.assert_called_once() - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_only_delete(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with only deleted chunks. - - This test verifies that when only deleted chunks are provided, only - delete_by_ids is called, not add_texts. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([], [], [delete_chunk], dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once_with([delete_chunk.index_node_id]) - - mock_vector.add_texts.assert_not_called() - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_update_child_chunk_vector_economy(self, mock_db, mock_vector_class): - """ - Test update_child_chunk_vector with economy indexing. - - This test verifies that child chunk vectors are not updated when - using economy indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.update_child_chunk_vector([new_chunk], [], [], dataset) - - # Assert - mock_vector.delete_by_ids.assert_not_called() - - mock_vector.add_texts.assert_not_called() - - # ======================================================================== - # Tests for delete_child_chunk_vector - # ======================================================================== - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_delete_child_chunk_vector_high_quality(self, mock_db, mock_vector_class): - """ - Test delete_child_chunk_vector with high_quality indexing. - - This test verifies that child chunk vectors are correctly deleted - when using high_quality indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.delete_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.delete_by_ids.assert_called_once_with([child_chunk.index_node_id]) - - @patch("services.vector_service.Vector") - @patch("services.vector_service.db") - def test_delete_child_chunk_vector_economy(self, mock_db, mock_vector_class): - """ - Test delete_child_chunk_vector with economy indexing. - - This test verifies that child chunk vectors are not deleted when - using economy indexing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) - - child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() - - mock_vector = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_class.return_value = mock_vector - - # Act - VectorService.delete_child_chunk_vector(child_chunk, dataset) - - # Assert - mock_vector.delete_by_ids.assert_not_called() - - -# ============================================================================ -# Tests for Vector Class -# ============================================================================ - - -class TestVector: - """ - Comprehensive unit tests for Vector class. - - This test class covers all methods of the Vector class, including - initialization, collection management, embedding operations, vector - database operations, and search functionality. - """ - - # ======================================================================== - # Tests for Vector Initialization - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_initialization_default_attributes(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector initialization with default attributes. - - This test verifies that Vector is correctly initialized with default - attributes when none are provided. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - # Act - vector = Vector(dataset=dataset) - - # Assert - assert vector._dataset == dataset - - assert vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash"] - - mock_get_embeddings.assert_called_once() - - mock_init_vector.assert_called_once() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_initialization_custom_attributes(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector initialization with custom attributes. - - This test verifies that Vector is correctly initialized with custom - attributes when provided. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - custom_attributes = ["custom_attr1", "custom_attr2"] - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - # Act - vector = Vector(dataset=dataset, attributes=custom_attributes) - - # Assert - assert vector._dataset == dataset - - assert vector._attributes == custom_attributes - - # ======================================================================== - # Tests for Vector.create - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_create_with_texts(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.create with texts list. - - This test verifies that documents are correctly embedded and created - in the vector store with batch processing. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [ - VectorServiceTestDataFactory.create_rag_document_mock(page_content=f"Content {i}") for i in range(5) - ] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536] * 5) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.create(texts=documents) - - # Assert - mock_embeddings.embed_documents.assert_called() - - mock_vector_processor.create.assert_called() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_create_empty_texts(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.create with empty texts list. - - This test verifies that when texts is None or empty, no operations - are performed. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.create(texts=None) - - # Assert - mock_embeddings.embed_documents.assert_not_called() - - mock_vector_processor.create.assert_not_called() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_create_large_batch(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.create with large batch of documents. - - This test verifies that large batches are correctly processed in - chunks of 1000 documents. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [ - VectorServiceTestDataFactory.create_rag_document_mock(page_content=f"Content {i}") for i in range(2500) - ] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536] * 1000) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.create(texts=documents) - - # Assert - # Should be called 3 times (1000, 1000, 500) - assert mock_embeddings.embed_documents.call_count == 3 - - assert mock_vector_processor.create.call_count == 3 - - # ======================================================================== - # Tests for Vector.add_texts - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_add_texts_without_duplicate_check(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.add_texts without duplicate check. - - This test verifies that documents are added without checking for - duplicates when duplicate_check is False. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [VectorServiceTestDataFactory.create_rag_document_mock()] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536]) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.add_texts(documents, duplicate_check=False) - - # Assert - mock_embeddings.embed_documents.assert_called_once() - - mock_vector_processor.create.assert_called_once() - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_add_texts_with_duplicate_check(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.add_texts with duplicate check. - - This test verifies that duplicate documents are filtered out when - duplicate_check is True. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - documents = [VectorServiceTestDataFactory.create_rag_document_mock(doc_id="doc-123")] - - mock_embeddings = Mock() - - mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536]) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(return_value=True) # Document exists - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.add_texts(documents, duplicate_check=True) - - # Assert - mock_vector_processor.text_exists.assert_called_once_with("doc-123") - - mock_embeddings.embed_documents.assert_not_called() - - mock_vector_processor.create.assert_not_called() - - # ======================================================================== - # Tests for Vector.text_exists - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_text_exists_true(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.text_exists when text exists. - - This test verifies that text_exists correctly returns True when - a document exists in the vector store. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(return_value=True) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.text_exists("doc-123") - - # Assert - assert result is True - - mock_vector_processor.text_exists.assert_called_once_with("doc-123") - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_text_exists_false(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.text_exists when text does not exist. - - This test verifies that text_exists correctly returns False when - a document does not exist in the vector store. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(return_value=False) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.text_exists("doc-123") - - # Assert - assert result is False - - mock_vector_processor.text_exists.assert_called_once_with("doc-123") - - # ======================================================================== - # Tests for Vector.delete_by_ids - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_delete_by_ids(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.delete_by_ids. - - This test verifies that documents are correctly deleted by their IDs. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - ids = ["doc-1", "doc-2", "doc-3"] - - # Act - vector.delete_by_ids(ids) - - # Assert - mock_vector_processor.delete_by_ids.assert_called_once_with(ids) - - # ======================================================================== - # Tests for Vector.delete_by_metadata_field - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_delete_by_metadata_field(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.delete_by_metadata_field. - - This test verifies that documents are correctly deleted by metadata - field value. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.delete_by_metadata_field("dataset_id", "dataset-123") - - # Assert - mock_vector_processor.delete_by_metadata_field.assert_called_once_with("dataset_id", "dataset-123") - - # ======================================================================== - # Tests for Vector.search_by_vector - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_search_by_vector(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.search_by_vector. - - This test verifies that vector search correctly embeds the query - and searches the vector store. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - query = "test query" - - query_vector = [0.1] * 1536 - - mock_embeddings = Mock() - - mock_embeddings.embed_query = Mock(return_value=query_vector) - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.search_by_vector = Mock(return_value=[]) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.search_by_vector(query) - - # Assert - mock_embeddings.embed_query.assert_called_once_with(query) - - mock_vector_processor.search_by_vector.assert_called_once_with(query_vector) - - assert result == [] - - # ======================================================================== - # Tests for Vector.search_by_full_text - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_search_by_full_text(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector.search_by_full_text. - - This test verifies that full-text search correctly searches the - vector store without embedding the query. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - query = "test query" - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.search_by_full_text = Mock(return_value=[]) - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - result = vector.search_by_full_text(query) - - # Assert - mock_vector_processor.search_by_full_text.assert_called_once_with(query) - - assert result == [] - - # ======================================================================== - # Tests for Vector.delete - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.redis_client") - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_delete(self, mock_get_embeddings, mock_init_vector, mock_redis_client): - """ - Test Vector.delete. - - This test verifies that the collection is deleted and Redis cache - is cleared. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.collection_name = "test_collection" - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - # Act - vector.delete() - - # Assert - mock_vector_processor.delete.assert_called_once() - - mock_redis_client.delete.assert_called_once_with("vector_indexing_test_collection") - - # ======================================================================== - # Tests for Vector.get_vector_factory - # ======================================================================== - - def test_vector_get_vector_factory_chroma(self): - """ - Test Vector.get_vector_factory for Chroma. - - This test verifies that the correct factory class is returned for - Chroma vector type. - """ - # Act - factory_class = Vector.get_vector_factory(VectorType.CHROMA) - - # Assert - assert factory_class is not None - - # Verify it's the correct factory by checking the module name - assert "chroma" in factory_class.__module__.lower() - - def test_vector_get_vector_factory_milvus(self): - """ - Test Vector.get_vector_factory for Milvus. - - This test verifies that the correct factory class is returned for - Milvus vector type. - """ - # Act - factory_class = Vector.get_vector_factory(VectorType.MILVUS) - - # Assert - assert factory_class is not None - - assert "milvus" in factory_class.__module__.lower() - - def test_vector_get_vector_factory_invalid_type(self): - """ - Test Vector.get_vector_factory with invalid vector type. - - This test verifies that a ValueError is raised when an invalid - vector type is provided. - """ - # Act & Assert - with pytest.raises(ValueError, match="Vector store .* is not supported"): - Vector.get_vector_factory("invalid_type") - - # ======================================================================== - # Tests for Vector._filter_duplicate_texts - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_filter_duplicate_texts(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector._filter_duplicate_texts. - - This test verifies that duplicate documents are correctly filtered - based on doc_id in metadata. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_vector_processor.text_exists = Mock(side_effect=[True, False]) # First exists, second doesn't - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - doc1 = VectorServiceTestDataFactory.create_rag_document_mock(doc_id="doc-1") - - doc2 = VectorServiceTestDataFactory.create_rag_document_mock(doc_id="doc-2") - - documents = [doc1, doc2] - - # Act - filtered = vector._filter_duplicate_texts(documents) - - # Assert - assert len(filtered) == 1 - - assert filtered[0].metadata["doc_id"] == "doc-2" - - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") - def test_vector_filter_duplicate_texts_no_metadata(self, mock_get_embeddings, mock_init_vector): - """ - Test Vector._filter_duplicate_texts with documents without metadata. - - This test verifies that documents without metadata are not filtered. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock() - - mock_embeddings = Mock() - - mock_get_embeddings.return_value = mock_embeddings - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - vector = Vector(dataset=dataset) - - doc1 = Document(page_content="Content 1", metadata=None) - - doc2 = Document(page_content="Content 2", metadata={}) - - documents = [doc1, doc2] - - # Act - filtered = vector._filter_duplicate_texts(documents) - - # Assert - assert len(filtered) == 2 - - # ======================================================================== - # Tests for Vector._get_embeddings - # ======================================================================== - - @patch("core.rag.datasource.vdb.vector_factory.CacheEmbedding") - @patch("core.rag.datasource.vdb.vector_factory.ModelManager.for_tenant") - @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") - def test_vector_get_embeddings(self, mock_init_vector, mock_model_manager, mock_cache_embedding): - """ - Test Vector._get_embeddings. - - This test verifies that embeddings are correctly retrieved from - ModelManager and wrapped in CacheEmbedding. - """ - # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock( - embedding_model_provider="openai", embedding_model="text-embedding-ada-002" - ) - - mock_embedding_model = VectorServiceTestDataFactory.create_embedding_model_instance_mock() - - mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_model - - mock_cache_embedding_instance = Mock() - - mock_cache_embedding.return_value = mock_cache_embedding_instance - - mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock() - - mock_init_vector.return_value = mock_vector_processor - - # Act - vector = Vector(dataset=dataset) - - # Assert - mock_model_manager.return_value.get_model_instance.assert_called_once() - - mock_cache_embedding.assert_called_once_with(mock_embedding_model) - - assert vector._embeddings == mock_cache_embedding_instance diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 5dad58b8f1..b74079bd69 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -89,9 +89,6 @@ def mock_db_session(): session = MagicMock() session._shared_data = {"dataset": None, "documents": []} - # Keep a pointer so repeated Document.first() calls iterate across provided docs - session._doc_first_idx = 0 - def _get_entity(stmt) -> type | None: """Extract the mapped entity class from a SQLAlchemy select statement.""" try: @@ -1591,18 +1588,7 @@ class TestDocumentIndexingTaskSummaryFlow: need_summary=True, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")] - phase1_document_query = MagicMock() - phase1_document_query.where.return_value = phase1_document_query - phase1_document_query.all.return_value = phase1_docs - - summary_document_query = MagicMock() - summary_document_query.where.return_value = summary_document_query - summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status] session1 = MagicMock() session2 = MagicMock() @@ -1657,18 +1643,6 @@ class TestDocumentIndexingTaskSummaryFlow: need_summary=True, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - phase1_query = MagicMock() - phase1_query.where.return_value = phase1_query - phase1_query.all.return_value = [SimpleNamespace(id="doc-1")] - - summary_query = MagicMock() - summary_query.where.return_value = summary_query - summary_query.all.return_value = [doc_eligible] - session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext()