refactor: Extract dify_graph variable conversion helpers

This commit is contained in:
-LAN-
2026-03-16 04:23:09 +08:00
parent fbb74a4af9
commit 93a5ad3d08
7 changed files with 256 additions and 237 deletions

View File

@@ -16,11 +16,10 @@ from dify_graph.constants import (
)
from dify_graph.file import File, FileAttribute, file_manager
from dify_graph.system_variable import SystemVariable
from dify_graph.variables import Segment, SegmentGroup, VariableBase
from dify_graph.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable
from dify_graph.variables.consts import SELECTORS_LENGTH
from dify_graph.variables.segments import FileSegment, ObjectSegment
from dify_graph.variables.variables import RAGPipelineVariableInput, Variable
from factories import variable_factory
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
@@ -114,10 +113,10 @@ class VariablePool(BaseModel):
if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
variable = segment_to_variable(segment=value, selector=selector)
else:
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
segment = build_segment(value)
variable = segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `Variable`,
@@ -180,7 +179,7 @@ class VariablePool(BaseModel):
return None
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
return variable_factory.build_segment(attr_value)
return build_segment(attr_value)
# Navigate through nested attributes
result: Any = segment
@@ -191,7 +190,7 @@ class VariablePool(BaseModel):
return None
# Return result as Segment
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
return result if isinstance(result, Segment) else build_segment(result)
def _extract_value(self, obj: Any):
"""Extract the actual value from an ObjectSegment."""
@@ -212,7 +211,7 @@ class VariablePool(BaseModel):
"""
if not isinstance(obj, dict) or attr not in obj:
return None
return variable_factory.build_segment(obj.get(attr))
return build_segment(obj.get(attr))
def remove(self, selector: Sequence[str], /):
"""
@@ -239,7 +238,7 @@ class VariablePool(BaseModel):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)
else:
segments.append(variable_factory.build_segment(part))
segments.append(build_segment(part))
return SegmentGroup(value=segments)
def get_file(self, selector: Sequence[str], /) -> FileSegment | None: