feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -3,9 +3,6 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.model_manager import ModelInstance
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
@@ -14,7 +11,7 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from dify_graph.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult
@@ -27,10 +24,11 @@ from dify_graph.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.protocols import TemplateRenderer
from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
@@ -49,15 +47,20 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
class _PassthroughPromptMessageSerializer:
def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
_ = model_mode
return list(prompt_messages)
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_prompt_message_serializer: PromptMessageSerializerProtocol
_model_instance: PreparedLLMProtocol
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
@@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
credentials_provider: object | None = None,
model_factory: object | None = None,
model_instance: PreparedLLMProtocol,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
llm_file_saver: LLMFileSaver,
prompt_message_serializer: PromptMessageSerializerProtocol | None = None,
):
super().__init__(
id=id,
@@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._credentials_provider = credentials_provider
self._model_factory = model_factory
_ = credentials_provider, model_factory, http_client
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver
self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer()
@classmethod
def version(cls):
@@ -169,7 +166,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.require_dify_context().user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
@@ -205,7 +201,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
category_id = category_id_result
process_data = {
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
"prompts": self._prompt_message_serializer.serialize(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
@@ -247,7 +243,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
)
@property
def model_instance(self) -> ModelInstance:
def model_instance(self) -> PreparedLLMProtocol:
return self._model_instance
@classmethod
@@ -285,7 +281,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
context: str | None,
) -> int:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
@@ -334,7 +330,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
model_mode = LLMMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:
@@ -350,7 +346,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
if model_mode == LLMMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
@@ -381,7 +377,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
elif model_mode == LLMMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
histories=memory_str,