feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -7,10 +7,10 @@ from pydantic import (
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
from dify_graph.prompt_entities import MemoryConfig
from dify_graph.variables.types import SegmentType
_OLD_BOOL_TYPE_NAME = "bool"

View File

@@ -5,11 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.model_manager import ModelInstance
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
BuiltinNodeTypes,
@@ -17,8 +12,8 @@ from dify_graph.enums import (
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File
from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@@ -27,14 +22,15 @@ from dify_graph.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import variable_template_parser
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.llm import llm_utils
from dify_graph.nodes.llm import LLMNode, llm_utils
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type
@@ -66,7 +62,6 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.runtime import GraphRuntimeState
@@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: PreparedLLMProtocol
_prompt_message_serializer: PromptMessageSerializerProtocol
_memory: PromptMessageMemory | None
def __init__(
@@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
credentials_provider: object | None = None,
model_factory: object | None = None,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None = None,
prompt_message_serializer: PromptMessageSerializerProtocol,
) -> None:
super().__init__(
id=id,
@@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._credentials_provider = credentials_provider
self._model_factory = model_factory
_ = credentials_provider, model_factory
self._model_instance = model_instance
self._prompt_message_serializer = prompt_message_serializer
self._memory = memory
@classmethod
@@ -164,13 +159,12 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
model_instance = self._model_instance
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
if model_schema.model_type != ModelType.LLM:
raise InvalidModelTypeError("Model is not a Large Language Model")
memory = self._memory
if (
@@ -210,8 +204,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
process_data = {
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
"prompts": self._prompt_message_serializer.serialize(
model_mode=node_data.model.mode,
prompt_messages=prompt_messages,
),
"usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
@@ -287,18 +282,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def _invoke(
self,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: Sequence[str],
stop: Sequence[str] | None,
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=dict(model_instance.parameters),
tools=tools,
stop=list(stop),
stream=False,
user=self.require_dify_context().user_id,
invoke_result = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=dict(model_instance.parameters),
tools=tools or None,
stop=stop,
stream=False,
),
)
# handle invoke result
@@ -317,7 +314,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -329,7 +326,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
content=query, structure=json.dumps(node_data.get_parameter_json_schema())
)
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
@@ -340,15 +336,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
prompt_template = self._get_function_calling_prompt_template(
node_data, query, variable_pool, memory, rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=files,
context="",
memory_config=node_data.memory,
memory=None,
prompt_messages = self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=files,
vision_enabled=node_data.vision.enabled,
image_detail_config=vision_detail,
)
@@ -405,7 +397,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -413,9 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Generate prompt engineering prompt.
"""
model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION:
if data.model.mode == LLMMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
node_data=data,
query=query,
@@ -425,7 +415,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
vision_detail=vision_detail,
)
elif model_mode == ModelMode.CHAT:
if data.model.mode == LLMMode.CHAT:
return self._generate_prompt_engineering_chat_prompt(
node_data=data,
query=query,
@@ -435,15 +425,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
vision_detail=vision_detail,
)
else:
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}")
def _generate_prompt_engineering_completion_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -451,7 +440,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Generate completion prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
@@ -462,27 +450,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
query="",
files=files,
context="",
memory_config=node_data.memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
return self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=files,
vision_enabled=node_data.vision.enabled,
image_detail_config=vision_detail,
)
return prompt_messages
def _generate_prompt_engineering_chat_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -490,7 +471,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Generate chat prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
@@ -508,15 +488,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
max_token_limit=rest_token,
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=files,
context="",
memory_config=node_data.memory,
memory=None,
prompt_messages = self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=files,
vision_enabled=node_data.vision.enabled,
image_detail_config=vision_detail,
)
@@ -717,8 +693,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
) -> list[LLMNodeChatModelMessage]:
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -727,15 +702,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
if node_data.model.mode == LLMMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
else:
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
def _get_prompt_engineering_prompt_template(
self,
@@ -744,8 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -754,64 +727,53 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
if node_data.model.mode == LLMMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
elif model_mode == ModelMode.COMPLETION:
return CompletionModelPromptTemplate(
if node_data.model.mode == LLMMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=COMPLETION_GENERATE_JSON_PROMPT.format(
histories=memory_str, text=input_text, instruction=instruction
)
.replace("{γγγ", "")
.replace("}γγγ", "")
.replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())),
)
else:
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
def _calculate_rest_token(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
context: str | None,
) -> int:
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else:
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
prompt_messages = self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=[],
vision_enabled=False,
context=context,
)
rest_tokens = 2000
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
curr_message_tokens = (
model_type_instance.get_num_tokens(
model_instance.model_name, model_instance.credentials, prompt_messages
)
+ 1000
) # add 1000 to ensure tool call messages
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
@@ -828,8 +790,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
def _compile_prompt_messages(
    self,
    *,
    model_instance: PreparedLLMProtocol,
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
    files: Sequence[File],
    vision_enabled: bool,
    context: str | None = "",
    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
    """
    Render a prompt template into prompt messages via ``LLMNode.fetch_prompt_messages``.

    The call is made with an empty system query, no memory, no memory
    config and no Jinja2 variables; the template is resolved against this
    node's runtime variable pool and the model instance's own stop words.

    :param model_instance: prepared LLM the prompt is compiled for
    :param prompt_template: chat messages or a completion template to render
    :param files: files forwarded as ``sys_files``
    :param vision_enabled: forwarded unchanged as ``vision_enabled``
    :param context: context text forwarded unchanged (defaults to ``""``)
    :param image_detail_config: image detail level; ``DETAIL.HIGH`` is used
        when this is ``None`` (or otherwise falsy)
    :return: a new list holding the compiled prompt messages
    """
    # Fall back to HIGH detail when the caller did not pick a level.
    detail_level = image_detail_config or ImagePromptMessageContent.DETAIL.HIGH
    pool = self.graph_runtime_state.variable_pool

    # fetch_prompt_messages returns a pair; only the messages are needed
    # here, so the second element is discarded.
    compiled = LLMNode.fetch_prompt_messages(
        sys_query="",
        sys_files=files,
        context=context,
        memory=None,
        model_instance=model_instance,
        prompt_template=prompt_template,
        stop=model_instance.stop,
        memory_config=None,
        vision_enabled=vision_enabled,
        vision_detail=detail_level,
        variable_pool=pool,
        jinja2_variables=[],
    )
    messages, _ = compiled
    return list(messages)
@property
def model_instance(self) -> ModelInstance:
def model_instance(self) -> PreparedLLMProtocol:
return self._model_instance
@classmethod