mirror of
https://github.com/langgenius/dify.git
synced 2026-05-31 19:00:22 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
|
||||
|
||||
class TemplateRenderError(ValueError):
|
||||
"""Raised when rendering a Jinja2 template fails."""
|
||||
|
||||
|
||||
class Jinja2TemplateRenderer(Protocol):
|
||||
"""Render Jinja2 templates for template transform nodes."""
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
"""Render a Jinja2 template with provided variables."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Adapter that renders Jinja2 templates via CodeExecutor."""
|
||||
|
||||
_code_executor: WorkflowCodeExecutor
|
||||
|
||||
def __init__(self, code_executor: WorkflowCodeExecutor) -> None:
|
||||
self._code_executor = code_executor
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables)
|
||||
except Exception as exc:
|
||||
if self._code_executor.is_execution_error(exc):
|
||||
raise TemplateRenderError(str(exc)) from exc
|
||||
raise
|
||||
|
||||
rendered = result.get("result")
|
||||
if not isinstance(rendered, str):
|
||||
raise TemplateRenderError("Template render result must be a string.")
|
||||
return rendered
|
||||
@@ -4,9 +4,10 @@ from typing import TYPE_CHECKING, Any
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
from dify_graph.template_rendering import (
|
||||
Jinja2TemplateRenderer,
|
||||
TemplateRenderError,
|
||||
)
|
||||
@@ -20,7 +21,7 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
_jinja2_template_renderer: Jinja2TemplateRenderer
|
||||
_max_output_length: int
|
||||
|
||||
def __init__(
|
||||
@@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer,
|
||||
jinja2_template_renderer: Jinja2TemplateRenderer,
|
||||
max_output_length: int | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@@ -39,7 +40,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._template_renderer = template_renderer
|
||||
self._jinja2_template_renderer = jinja2_template_renderer
|
||||
|
||||
if max_output_length is not None and max_output_length <= 0:
|
||||
raise ValueError("max_output_length must be a positive integer")
|
||||
@@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
rendered = self._template_renderer.render_template(self.node_data.template, variables)
|
||||
rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables)
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
@@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: TemplateTransformNodeData | Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
_ = graph_config
|
||||
raw_variables = (
|
||||
node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", [])
|
||||
)
|
||||
variable_mapping: dict[str, Sequence[str]] = {}
|
||||
for variable_selector in raw_variables:
|
||||
if isinstance(variable_selector, VariableSelector):
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
continue
|
||||
|
||||
if not isinstance(variable_selector, Mapping):
|
||||
continue
|
||||
|
||||
variable = variable_selector.get("variable")
|
||||
value_selector = variable_selector.get("value_selector")
|
||||
if (
|
||||
isinstance(variable, str)
|
||||
and isinstance(value_selector, Sequence)
|
||||
and all(isinstance(selector_part, str) for selector_part in value_selector)
|
||||
):
|
||||
variable_mapping[node_id + "." + variable] = list(value_selector)
|
||||
|
||||
return variable_mapping
|
||||
|
||||
Reference in New Issue
Block a user