From 047ea8c143efaaec61a5bc2b6fc419d85e1a9bbe Mon Sep 17 00:00:00 2001
From: Novice
Date: Thu, 18 Dec 2025 10:09:31 +0800
Subject: [PATCH] chore: improve type checking

---
 api/core/app/entities/queue_entities.py           |   3 +-
 .../index_tool_callback_handler.py                |   3 +-
 .../workflow/nodes/test_llm_node_streaming.py     | 148 ++++++++++++++++++
 3 files changed, 151 insertions(+), 3 deletions(-)
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py

diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index e07efbc38c..31b95ad165 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -7,9 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.workflow.entities import AgentNodeStrategyInit
+from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
 from core.workflow.enums import WorkflowNodeExecutionMetadataKey
-from core.workflow.graph_events import ToolCall, ToolResult
 from core.workflow.nodes import NodeType
 
 
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index d0279349ca..5249fea8cd 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -5,7 +5,6 @@ from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.models.document import Document
@@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
     # TODO(-LAN-): Improve type check
     def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
         """Handle return_retriever_resource_info."""
+        from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
+
         self._queue_manager.publish(
             QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
         )
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py
new file mode 100644
index 0000000000..9d793f804f
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py
@@ -0,0 +1,148 @@
+import types
+from collections.abc import Generator
+from typing import Any
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.workflow.entities import ToolCallResult
+from core.workflow.entities.tool_entities import ToolResultStatus
+from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
+from core.workflow.nodes.llm.node import LLMNode
+
+
+class _StubModelInstance:
+    """Minimal stub to satisfy _stream_llm_events signature."""
+
+    provider_model_bundle = None
+
+
+def _drain(generator: Generator[NodeEventBase, None, Any]):
+    events: list = []
+    try:
+        while True:
+            events.append(next(generator))
+    except StopIteration as exc:
+        return events, exc.value
+
+
+@pytest.fixture(autouse=True)
+def patch_deduct_llm_quota(monkeypatch):
+    # Avoid touching real quota logic during unit tests
+    monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)
+
+
+def _make_llm_node(reasoning_format: str) -> LLMNode:
+    node = LLMNode.__new__(LLMNode)
+    object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
+    object.__setattr__(node, "tenant_id", "tenant")
+    return node
+
+
+def test_stream_llm_events_extracts_reasoning_for_tagged():
+    node = _make_llm_node(reasoning_format="tagged")
+    tagged_text = "<think>Thought</think>Answer"
+    usage = LLMUsage.empty_usage()
+
+    def generator():
+        yield ModelInvokeCompletedEvent(
+            text=tagged_text,
+            usage=usage,
+            finish_reason="stop",
+            reasoning_content="",
+            structured_output=None,
+        )
+
+    events, returned = _drain(
+        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
+    )
+
+    assert events == []
+    clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned
+    assert clean_text == tagged_text  # original preserved for output
+    assert reasoning_content == ""  # tagged mode keeps reasoning separate
+    assert gen_clean == "Answer"  # stripped content for generation
+    assert gen_reasoning == "Thought"  # reasoning extracted from tag
+    assert ret_usage == usage
+    assert finish_reason == "stop"
+    assert structured is None
+    assert gen_data is None
+
+    # generation building should include reasoning and sequence
+    generation_content = gen_clean or clean_text
+    sequence = [
+        {"type": "reasoning", "index": 0},
+        {"type": "content", "start": 0, "end": len(generation_content)},
+    ]
+    assert sequence == [
+        {"type": "reasoning", "index": 0},
+        {"type": "content", "start": 0, "end": len("Answer")},
+    ]
+
+
+def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
+    node = _make_llm_node(reasoning_format="tagged")
+    plain_text = "Hello world"
+    usage = LLMUsage.empty_usage()
+
+    def generator():
+        yield ModelInvokeCompletedEvent(
+            text=plain_text,
+            usage=usage,
+            finish_reason=None,
+            reasoning_content="",
+            structured_output=None,
+        )
+
+    events, returned = _drain(
+        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
+    )
+
+    assert events == []
+    _, _, gen_reasoning, gen_clean, *_ = returned
+    generation_content = gen_clean or plain_text
+    assert gen_reasoning == ""
+    assert generation_content == plain_text
+    # Empty reasoning should imply empty sequence in generation construction
+    sequence = []
+    assert sequence == []
+
+
+def test_serialize_tool_call_strips_files_to_ids():
+    file_cls = pytest.importorskip("core.file").File
+    file_type = pytest.importorskip("core.file.enums").FileType
+    transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod
+
+    file_with_id = file_cls(
+        id="f1",
+        tenant_id="t",
+        type=file_type.IMAGE,
+        transfer_method=transfer_method.REMOTE_URL,
+        remote_url="http://example.com/f1",
+        storage_key="k1",
+    )
+    file_with_related = file_cls(
+        id=None,
+        tenant_id="t",
+        type=file_type.IMAGE,
+        transfer_method=transfer_method.REMOTE_URL,
+        related_id="rel2",
+        remote_url="http://example.com/f2",
+        storage_key="k2",
+    )
+    tool_call = ToolCallResult(
+        id="tc",
+        name="do",
+        arguments='{"a":1}',
+        output="ok",
+        files=[file_with_id, file_with_related],
+        status=ToolResultStatus.SUCCESS,
+    )
+
+    serialized = LLMNode._serialize_tool_call(tool_call)
+
+    assert serialized["files"] == ["f1", "rel2"]
+    assert serialized["id"] == "tc"
+    assert serialized["name"] == "do"
+    assert serialized["arguments"] == '{"a":1}'
+    assert serialized["output"] == "ok"