mirror of
https://github.com/langgenius/dify.git
synced 2026-05-31 01:00:40 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -7,10 +7,10 @@ from pydantic import (
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from dify_graph.prompt_entities import MemoryConfig
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
_OLD_BOOL_TYPE_NAME = "bool"
|
||||
|
||||
@@ -5,11 +5,6 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
BuiltinNodeTypes,
|
||||
@@ -17,8 +12,8 @@ from dify_graph.enums import (
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.file import File
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@@ -27,14 +22,15 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base import variable_template_parser
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm import LLMNode, llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
|
||||
from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.types import ArrayValidation, SegmentType
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
@@ -66,7 +62,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
@@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR
|
||||
|
||||
_model_instance: ModelInstance
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
_model_instance: PreparedLLMProtocol
|
||||
_prompt_message_serializer: PromptMessageSerializerProtocol
|
||||
_memory: PromptMessageMemory | None
|
||||
|
||||
def __init__(
|
||||
@@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
credentials_provider: object | None = None,
|
||||
model_factory: object | None = None,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
prompt_message_serializer: PromptMessageSerializerProtocol,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
_ = credentials_provider, model_factory
|
||||
self._model_instance = model_instance
|
||||
self._prompt_message_serializer = prompt_message_serializer
|
||||
self._memory = memory
|
||||
|
||||
@classmethod
|
||||
@@ -164,13 +159,12 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
)
|
||||
|
||||
model_instance = self._model_instance
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
try:
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
except ValueError as exc:
|
||||
raise ModelSchemaNotFoundError("Model schema not found") from exc
|
||||
if model_schema.model_type != ModelType.LLM:
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
memory = self._memory
|
||||
|
||||
if (
|
||||
@@ -210,8 +204,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
process_data = {
|
||||
"model_mode": node_data.model.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=node_data.model.mode, prompt_messages=prompt_messages
|
||||
"prompts": self._prompt_message_serializer.serialize(
|
||||
model_mode=node_data.model.mode,
|
||||
prompt_messages=prompt_messages,
|
||||
),
|
||||
"usage": None,
|
||||
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
|
||||
@@ -287,18 +282,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: Sequence[str],
|
||||
stop: Sequence[str] | None,
|
||||
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=dict(model_instance.parameters),
|
||||
tools=tools,
|
||||
stop=list(stop),
|
||||
stream=False,
|
||||
user=self.require_dify_context().user_id,
|
||||
invoke_result = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=dict(model_instance.parameters),
|
||||
tools=tools or None,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
),
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
@@ -317,7 +314,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
@@ -329,7 +326,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
content=query, structure=json.dumps(node_data.get_parameter_json_schema())
|
||||
)
|
||||
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query=query,
|
||||
@@ -340,15 +336,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
prompt_template = self._get_function_calling_prompt_template(
|
||||
node_data, query, variable_pool, memory, rest_token
|
||||
)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
prompt_messages = self._compile_prompt_messages(
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
|
||||
@@ -405,7 +397,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
@@ -413,9 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
"""
|
||||
model_mode = ModelMode(data.model.mode)
|
||||
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
if data.model.mode == LLMMode.COMPLETION:
|
||||
return self._generate_prompt_engineering_completion_prompt(
|
||||
node_data=data,
|
||||
query=query,
|
||||
@@ -425,7 +415,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
files=files,
|
||||
vision_detail=vision_detail,
|
||||
)
|
||||
elif model_mode == ModelMode.CHAT:
|
||||
if data.model.mode == LLMMode.CHAT:
|
||||
return self._generate_prompt_engineering_chat_prompt(
|
||||
node_data=data,
|
||||
query=query,
|
||||
@@ -435,15 +425,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
files=files,
|
||||
vision_detail=vision_detail,
|
||||
)
|
||||
else:
|
||||
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
|
||||
raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}")
|
||||
|
||||
def _generate_prompt_engineering_completion_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
@@ -451,7 +440,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
"""
|
||||
Generate completion prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query=query,
|
||||
@@ -462,27 +450,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
|
||||
)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
|
||||
query="",
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
# AdvancedPromptTransform is still typed against TokenBufferMemory.
|
||||
memory=cast(Any, memory),
|
||||
return self._compile_prompt_messages(
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _generate_prompt_engineering_chat_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
@@ -490,7 +471,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
"""
|
||||
Generate chat prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query=query,
|
||||
@@ -508,15 +488,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
max_token_limit=rest_token,
|
||||
)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
prompt_messages = self._compile_prompt_messages(
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
|
||||
@@ -717,8 +693,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
memory: PromptMessageMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
) -> list[LLMNodeChatModelMessage]:
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
@@ -727,15 +702,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
if node_data.model.mode == LLMMode.CHAT:
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
else:
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
|
||||
|
||||
def _get_prompt_engineering_prompt_template(
|
||||
self,
|
||||
@@ -744,8 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
memory: PromptMessageMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
@@ -754,64 +727,53 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
if node_data.model.mode == LLMMode.CHAT:
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return CompletionModelPromptTemplate(
|
||||
if node_data.model.mode == LLMMode.COMPLETION:
|
||||
return LLMNodeCompletionModelPromptTemplate(
|
||||
text=COMPLETION_GENERATE_JSON_PROMPT.format(
|
||||
histories=memory_str, text=input_text, instruction=instruction
|
||||
)
|
||||
.replace("{γγγ", "")
|
||||
.replace("}γγγ", "")
|
||||
.replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())),
|
||||
)
|
||||
else:
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
context: str | None,
|
||||
) -> int:
|
||||
try:
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
except ValueError as exc:
|
||||
raise ModelSchemaNotFoundError("Model schema not found") from exc
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
|
||||
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
|
||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
else:
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
prompt_messages = self._compile_prompt_messages(
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
files=[],
|
||||
vision_enabled=False,
|
||||
context=context,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
curr_message_tokens = (
|
||||
model_type_instance.get_num_tokens(
|
||||
model_instance.model_name, model_instance.credentials, prompt_messages
|
||||
)
|
||||
+ 1000
|
||||
) # add 1000 to ensure tool call messages
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_schema.parameter_rules:
|
||||
@@ -828,8 +790,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _compile_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
model_instance: PreparedLLMProtocol,
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
files: Sequence[File],
|
||||
vision_enabled: bool,
|
||||
context: str | None = "",
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
prompt_messages, _ = LLMNode.fetch_prompt_messages(
|
||||
sys_query="",
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=None,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=model_instance.stop,
|
||||
memory_config=None,
|
||||
vision_enabled=vision_enabled,
|
||||
vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
)
|
||||
return list(prompt_messages)
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
def model_instance(self) -> PreparedLLMProtocol:
|
||||
return self._model_instance
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user