chore: improve type checking

Author: Novice
Date: 2025-12-18 10:09:31 +08:00
Parent: f54b9b12b0
Commit: 047ea8c143

3 changed files with 151 additions and 3 deletions

View File

@@ -7,9 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.workflow.entities import AgentNodeStrategyInit
+from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
 from core.workflow.enums import WorkflowNodeExecutionMetadataKey
-from core.workflow.graph_events import ToolCall, ToolResult
 from core.workflow.nodes import NodeType
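
ToolCall and ToolResult are now imported from core.workflow.entities rather than core.workflow.graph_events, which suggests the entities package re-exports these types. A hypothetical sketch of such a re-export (not shown in this diff):

    # core/workflow/entities/__init__.py (hypothetical)
    from core.workflow.entities.tool_entities import ToolCall, ToolResult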

View File

@@ -5,7 +5,6 @@ from sqlalchemy import select
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.models.document import Document
@@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
+    # TODO(-LAN-): Improve type check
     def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
         """Handle return_retriever_resource_info."""
+        from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 
         self._queue_manager.publish(
             QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
         )
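
The top-level QueueRetrieverResourcesEvent import moves into the method body, presumably to break an import cycle between the queue-entities and callback modules (the diff itself does not state the motive). A minimal sketch of the deferred-import pattern, with hypothetical module names:

    # module_a.py
    import module_b  # module-level import

    # module_b.py
    def handler():
        from module_a import helper  # deferred to call time, avoiding the a <-> b cycle
        return helper()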

View File

@@ -0,0 +1,148 @@
import types
from collections.abc import Generator
from typing import Any

import pytest

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import ToolCallResult
from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
from core.workflow.nodes.llm.node import LLMNode


class _StubModelInstance:
    """Minimal stub to satisfy the _stream_llm_events signature."""

    # The tests below build an equivalent types.SimpleNamespace inline.
    provider_model_bundle = None


def _drain(generator: Generator[NodeEventBase, None, Any]):
    """Exhaust the generator, returning (yielded events, generator return value)."""
    events: list = []
    try:
        while True:
            events.append(next(generator))
    except StopIteration as exc:
        # A generator's return value is carried on StopIteration.value.
        return events, exc.value
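
# _drain relies on the fact that a generator's `return` value surfaces as
# StopIteration.value; a standalone illustration:
#
#     def gen():
#         yield 1
#         return "done"
#
#     it = gen()
#     assert next(it) == 1
#     try:
#         next(it)
#     except StopIteration as exc:
#         assert exc.value == "done"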


@pytest.fixture(autouse=True)
def patch_deduct_llm_quota(monkeypatch):
    # Avoid touching real quota logic during unit tests.
    monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)


def _make_llm_node(reasoning_format: str) -> LLMNode:
    # Bypass __init__ entirely; only the attributes the tests touch are set.
    node = LLMNode.__new__(LLMNode)
    object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
    object.__setattr__(node, "tenant_id", "tenant")
    return node
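
# object.__setattr__ writes directly to the instance, sidestepping whatever
# __setattr__ override (e.g. pydantic-style validation) LLMNode may define:
#
#     class Guarded:
#         def __setattr__(self, name, value):
#             raise AttributeError("frozen")
#
#     g = Guarded()
#     object.__setattr__(g, "x", 1)  # succeeds despite the override
#     assert g.x == 1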


def test_stream_llm_events_extracts_reasoning_for_tagged():
    node = _make_llm_node(reasoning_format="tagged")
    tagged_text = "<think>Thought</think>Answer"
    usage = LLMUsage.empty_usage()

    def generator():
        yield ModelInvokeCompletedEvent(
            text=tagged_text,
            usage=usage,
            finish_reason="stop",
            reasoning_content="",
            structured_output=None,
        )

    events, returned = _drain(
        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
    )

    assert events == []
    clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned
    assert clean_text == tagged_text  # original preserved for output
    assert reasoning_content == ""  # tagged mode keeps reasoning separate
    assert gen_clean == "Answer"  # stripped content for generation
    assert gen_reasoning == "Thought"  # reasoning extracted from <think> tag
    assert ret_usage == usage
    assert finish_reason == "stop"
    assert structured is None
    assert gen_data is None

    # Generation building should include reasoning and sequence.
    generation_content = gen_clean or clean_text
    sequence = [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len(generation_content)},
    ]
    assert sequence == [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len("Answer")},
    ]


def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
    node = _make_llm_node(reasoning_format="tagged")
    plain_text = "Hello world"
    usage = LLMUsage.empty_usage()

    def generator():
        yield ModelInvokeCompletedEvent(
            text=plain_text,
            usage=usage,
            finish_reason=None,
            reasoning_content="",
            structured_output=None,
        )

    events, returned = _drain(
        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
    )

    assert events == []
    _, _, gen_reasoning, gen_clean, *_ = returned
    generation_content = gen_clean or plain_text
    assert gen_reasoning == ""
    assert generation_content == plain_text
    # Empty reasoning should imply an empty sequence in generation construction.
    sequence = []
    assert sequence == []


def test_serialize_tool_call_strips_files_to_ids():
    file_cls = pytest.importorskip("core.file").File
    file_type = pytest.importorskip("core.file.enums").FileType
    transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod

    file_with_id = file_cls(
        id="f1",
        tenant_id="t",
        type=file_type.IMAGE,
        transfer_method=transfer_method.REMOTE_URL,
        remote_url="http://example.com/f1",
        storage_key="k1",
    )
    file_with_related = file_cls(
        id=None,
        tenant_id="t",
        type=file_type.IMAGE,
        transfer_method=transfer_method.REMOTE_URL,
        related_id="rel2",
        remote_url="http://example.com/f2",
        storage_key="k2",
    )
    tool_call = ToolCallResult(
        id="tc",
        name="do",
        arguments='{"a":1}',
        output="ok",
        files=[file_with_id, file_with_related],
        status=ToolResultStatus.SUCCESS,
    )

    serialized = LLMNode._serialize_tool_call(tool_call)

    assert serialized["files"] == ["f1", "rel2"]
    assert serialized["id"] == "tc"
    assert serialized["name"] == "do"
    assert serialized["arguments"] == '{"a":1}'
    assert serialized["output"] == "ok"