feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -1,41 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
from dify_graph.nodes.code.entities import CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: WorkflowCodeExecutor
def __init__(self, code_executor: WorkflowCodeExecutor) -> None:
self._code_executor = code_executor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables)
except Exception as exc:
if self._code_executor.is_execution_error(exc):
raise TemplateRenderError(str(exc)) from exc
raise
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered

View File

@@ -4,9 +4,10 @@ from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData
from dify_graph.nodes.template_transform.template_renderer import (
from dify_graph.template_rendering import (
Jinja2TemplateRenderer,
TemplateRenderError,
)
@@ -20,7 +21,7 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
_jinja2_template_renderer: Jinja2TemplateRenderer
_max_output_length: int
def __init__(
@@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer,
jinja2_template_renderer: Jinja2TemplateRenderer,
max_output_length: int | None = None,
) -> None:
super().__init__(
@@ -39,7 +40,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer
self._jinja2_template_renderer = jinja2_template_renderer
if max_output_length is not None and max_output_length <= 0:
raise ValueError("max_output_length must be a positive integer")
@@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
@@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: TemplateTransformNodeData | Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
}
_ = graph_config
raw_variables = (
node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", [])
)
variable_mapping: dict[str, Sequence[str]] = {}
for variable_selector in raw_variables:
if isinstance(variable_selector, VariableSelector):
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
continue
if not isinstance(variable_selector, Mapping):
continue
variable = variable_selector.get("variable")
value_selector = variable_selector.get("value_selector")
if (
isinstance(variable, str)
and isinstance(value_selector, Sequence)
and all(isinstance(selector_part, str) for selector_part in value_selector)
):
variable_mapping[node_id + "." + variable] = list(value_selector)
return variable_mapping