refactor: decouple Node and NodeData (#22581)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
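At a glance: BaseNode drops its Generic[GenericNodeData] type parameter and no longer validates node_data in __init__. Each subclass now owns a typed _node_data model behind a small set of accessors, and callers use node.title, node.error_strategy, node.retry_config, and node.get_base_node_data() instead of reaching into node.node_data. A minimal sketch of the new subclass contract, distilled from the hunks below (ExampleNode and ExampleNodeData are illustrative names, not part of the commit):

    from collections.abc import Mapping
    from typing import Any, Optional

    from core.workflow.nodes.base import BaseNode
    from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
    from core.workflow.nodes.enums import ErrorStrategy, NodeType


    class ExampleNodeData(BaseNodeData):  # hypothetical data model
        pass


    class ExampleNode(BaseNode):  # hypothetical node; real subclasses follow this shape
        _node_type = NodeType.ANSWER  # each subclass pins its own NodeType

        _node_data: ExampleNodeData

        def init_node_data(self, data: Mapping[str, Any]) -> None:
            # validation moves here, out of BaseNode.__init__
            self._node_data = ExampleNodeData.model_validate(data)

        def _get_error_strategy(self) -> Optional[ErrorStrategy]:
            return self._node_data.error_strategy

        def _get_retry_config(self) -> RetryConfig:
            return self._node_data.retry_config

        def _get_title(self) -> str:
            return self._node_data.title

        def _get_description(self) -> Optional[str]:
            return self._node_data.desc

        def _get_default_value_dict(self) -> dict[str, Any]:
            return self._node_data.default_value_dict

        def get_base_node_data(self) -> BaseNodeData:
            return self._node_data

        @classmethod
        def version(cls) -> str:
            return "1"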
@@ -17,7 +17,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
 from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
 from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
 from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
 from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@@ -169,7 +169,3 @@ class AppQueueManager:
                 raise TypeError(
                     "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
                 )
-
-
-class GenerateTaskStoppedError(Exception):
-    pass
@@ -118,7 +118,7 @@ class AppRunner:
         else:
             memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

-        model_mode = ModelMode.value_of(model_config.mode)
+        model_mode = ModelMode(model_config.mode)
         prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
         if model_mode == ModelMode.COMPLETION:
             advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
@@ -11,10 +11,11 @@ from configs import dify_config
 from constants import UUID_NIL
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.chat.app_config_manager import ChatAppConfigManager
 from core.app.apps.chat.app_runner import ChatAppRunner
 from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
@@ -10,10 +10,11 @@ from pydantic import ValidationError
 from configs import dify_config
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
 from core.app.apps.completion.app_runner import CompletionAppRunner
 from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
api/core/app/apps/exc.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+class GenerateTaskStoppedError(Exception):
+    pass
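Moving GenerateTaskStoppedError into its own leaf module lets callers import the exception without pulling in the queue-manager machinery (and, presumably, without the import cycles that motivated the move). Every remaining import hunk of this kind in the commit is the same mechanical substitution:

    # old: the exception lived next to AppQueueManager
    from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom

    # new: the exception comes from the leaf module core/app/apps/exc.py
    from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
    from core.app.apps.exc import GenerateTaskStoppedError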
@@ -6,7 +6,8 @@ from typing import Optional, Union, cast

 from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
 from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
     AgentChatAppGenerateEntity,
@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
@@ -13,7 +13,8 @@ import contexts
 from configs import dify_config
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
 from core.app.apps.workflow.app_runner import WorkflowAppRunner
@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
@@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum):
     COMPLETION = "completion"
     CHAT = "chat"

-    @classmethod
-    def value_of(cls, value: str) -> "ModelMode":
-        """
-        Get value of given mode.
-
-        :param value: mode value
-        :return: mode
-        """
-        for mode in cls:
-            if mode.value == value:
-                return mode
-        raise ValueError(f"invalid mode value {value}")
-

 prompt_file_contents: dict[str, Any] = {}

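Because ModelMode subclasses enum.StrEnum, the hand-rolled value_of lookup duplicated what the enum constructor already does, including raising ValueError on an unknown value, so the next two hunks simply call the constructor. A quick illustration of the equivalence:

    import enum

    class ModelMode(enum.StrEnum):
        COMPLETION = "completion"
        CHAT = "chat"

    assert ModelMode("chat") is ModelMode.CHAT  # constructor looks up by value
    # ModelMode("bogus") raises ValueError, just as value_of did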
@@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         inputs = {key: str(value) for key, value in inputs.items()}

-        model_mode = ModelMode.value_of(model_config.mode)
+        model_mode = ModelMode(model_config.mode)
         if model_mode == ModelMode.CHAT:
             prompt_messages, stops = self._get_chat_model_prompt_messages(
                 app_mode=app_mode,
@@ -1137,7 +1137,7 @@ class DatasetRetrieval:
     def _get_prompt_template(
         self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
     ):
-        model_mode = ModelMode.value_of(mode)
+        model_mode = ModelMode(mode)
         input_text = query

         prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
@@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode


 class WorkflowNodeRunFailedError(Exception):
-    def __init__(self, node_instance: BaseNode, error: str):
-        self.node_instance = node_instance
-        self.error = error
-        super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
+    def __init__(self, node: BaseNode, err_msg: str):
+        self._node = node
+        self._error = err_msg
+        super().__init__(f"Node {node.title} run failed: {err_msg}")
@@ -1,3 +1,4 @@
 from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
+from .graph_engine import GraphEngine

-__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
+__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
@@ -12,7 +12,7 @@ from typing import Any, Optional, cast
 from flask import Flask, current_app

 from configs import dify_config
-from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData
 from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
 from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
 from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
 from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
-from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
 from core.workflow.utils import variable_utils
 from libs.flask_utils import preserve_flask_contexts
 from models.enums import UserFrom
@@ -260,12 +258,16 @@ class GraphEngine:
                 # convert to specific node
                 node_type = NodeType(node_config.get("data", {}).get("type"))
                 node_version = node_config.get("data", {}).get("version", "1")
+
+                # Import here to avoid circular import
+                from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+
                 node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

                 previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

                 # init workflow run state
-                node_instance = node_cls(  # type: ignore
+                node = node_cls(
                     id=route_node_state.id,
                     config=node_config,
                     graph_init_params=self.init_params,
@@ -274,11 +276,11 @@ class GraphEngine:
                     previous_node_id=previous_node_id,
                     thread_pool_id=self.thread_pool_id,
                 )
-                node_instance = cast(BaseNode[BaseNodeData], node_instance)
+                node.init_node_data(node_config.get("data", {}))
                 try:
                     # run node
                     generator = self._run_node(
-                        node_instance=node_instance,
+                        node=node,
                         route_node_state=route_node_state,
                         parallel_id=in_parallel_id,
                         parallel_start_node_id=parallel_start_node_id,
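Node construction is now two-phase: the engine instantiates the node class without touching its data, then hands the raw "data" mapping to init_node_data, where the subclass validates it into its own typed model. Condensed from the two hunks above (other constructor kwargs omitted):

    node = node_cls(
        id=route_node_state.id,
        config=node_config,
        graph_init_params=self.init_params,
    )
    node.init_node_data(node_config.get("data", {}))  # subclass runs model_validate here
    generator = self._run_node(node=node, route_node_state=route_node_state)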
@@ -306,16 +308,16 @@
                     route_node_state.failed_reason = str(e)
                     yield NodeRunFailedEvent(
                         error=str(e),
-                        id=node_instance.id,
+                        id=node.id,
                         node_id=next_node_id,
                         node_type=node_type,
-                        node_data=node_instance.node_data,
+                        node_data=node.get_base_node_data(),
                         route_node_state=route_node_state,
                         parallel_id=in_parallel_id,
                         parallel_start_node_id=parallel_start_node_id,
                         parent_parallel_id=parent_parallel_id,
                         parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        node_version=node_instance.version(),
+                        node_version=node.version(),
                     )
                     raise e
@@ -337,7 +339,7 @@
                 edge = edge_mappings[0]
                 if (
                     previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
-                    and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
+                    and node.error_strategy == ErrorStrategy.FAIL_BRANCH
                     and edge.run_condition is None
                 ):
                     break
@@ -413,8 +415,8 @@

                     next_node_id = final_node_id
                 elif (
-                    node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
-                    and node_instance.should_continue_on_error
+                    node.continue_on_error
+                    and node.error_strategy == ErrorStrategy.FAIL_BRANCH
                     and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
                 ):
                     break
@@ -597,7 +599,7 @@

     def _run_node(
         self,
-        node_instance: BaseNode[BaseNodeData],
+        node: BaseNode,
         route_node_state: RouteNodeState,
         parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
@@ -611,29 +613,29 @@
         # trigger node run start event
         agent_strategy = (
             AgentNodeStrategyInit(
-                name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
-                icon=cast(AgentNode, node_instance).agent_strategy_icon,
+                name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
+                icon=cast(AgentNode, node).agent_strategy_icon,
             )
-            if node_instance.node_type == NodeType.AGENT
+            if node.type_ == NodeType.AGENT
             else None
         )
         yield NodeRunStartedEvent(
-            id=node_instance.id,
-            node_id=node_instance.node_id,
-            node_type=node_instance.node_type,
-            node_data=node_instance.node_data,
+            id=node.id,
+            node_id=node.node_id,
+            node_type=node.type_,
+            node_data=node.get_base_node_data(),
             route_node_state=route_node_state,
-            predecessor_node_id=node_instance.previous_node_id,
+            predecessor_node_id=node.previous_node_id,
             parallel_id=parallel_id,
             parallel_start_node_id=parallel_start_node_id,
             parent_parallel_id=parent_parallel_id,
             parent_parallel_start_node_id=parent_parallel_start_node_id,
             agent_strategy=agent_strategy,
-            node_version=node_instance.version(),
+            node_version=node.version(),
         )

-        max_retries = node_instance.node_data.retry_config.max_retries
-        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+        max_retries = node.retry_config.max_retries
+        retry_interval = node.retry_config.retry_interval_seconds
         retries = 0
         should_continue_retry = True
         while should_continue_retry and retries <= max_retries:
@@ -642,7 +644,7 @@
             retry_start_at = datetime.now(UTC).replace(tzinfo=None)
             # yield control to other threads
             time.sleep(0.001)
-            event_stream = node_instance.run()
+            event_stream = node.run()
             for event in event_stream:
                 if isinstance(event, GraphEngineEvent):
                     # add parallel info to iteration event
@@ -658,21 +660,21 @@
                     if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                         if (
                             retries == max_retries
-                            and node_instance.node_type == NodeType.HTTP_REQUEST
+                            and node.type_ == NodeType.HTTP_REQUEST
                             and run_result.outputs
-                            and not node_instance.should_continue_on_error
+                            and not node.continue_on_error
                         ):
                             run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
-                        if node_instance.should_retry and retries < max_retries:
+                        if node.retry and retries < max_retries:
                             retries += 1
                             route_node_state.node_run_result = run_result
                             yield NodeRunRetryEvent(
                                 id=str(uuid.uuid4()),
-                                node_id=node_instance.node_id,
-                                node_type=node_instance.node_type,
-                                node_data=node_instance.node_data,
+                                node_id=node.node_id,
+                                node_type=node.type_,
+                                node_data=node.get_base_node_data(),
                                 route_node_state=route_node_state,
-                                predecessor_node_id=node_instance.previous_node_id,
+                                predecessor_node_id=node.previous_node_id,
                                 parallel_id=parallel_id,
                                 parallel_start_node_id=parallel_start_node_id,
                                 parent_parallel_id=parent_parallel_id,
@@ -680,17 +682,17 @@
                                 error=run_result.error or "Unknown error",
                                 retry_index=retries,
                                 start_at=retry_start_at,
-                                node_version=node_instance.version(),
+                                node_version=node.version(),
                             )
                             time.sleep(retry_interval)
                             break
                     route_node_state.set_finished(run_result=run_result)

                     if run_result.status == WorkflowNodeExecutionStatus.FAILED:
-                        if node_instance.should_continue_on_error:
+                        if node.continue_on_error:
                             # if run failed, handle error
                             run_result = self._handle_continue_on_error(
-                                node_instance,
+                                node,
                                 event.run_result,
                                 self.graph_runtime_state.variable_pool,
                                 handle_exceptions=handle_exceptions,
@@ -701,44 +703,44 @@
                             for variable_key, variable_value in run_result.outputs.items():
                                 # append variables to variable pool recursively
                                 self._append_variables_recursively(
-                                    node_id=node_instance.node_id,
+                                    node_id=node.node_id,
                                     variable_key_list=[variable_key],
                                     variable_value=variable_value,
                                 )
                             yield NodeRunExceptionEvent(
                                 error=run_result.error or "System Error",
-                                id=node_instance.id,
-                                node_id=node_instance.node_id,
-                                node_type=node_instance.node_type,
-                                node_data=node_instance.node_data,
+                                id=node.id,
+                                node_id=node.node_id,
+                                node_type=node.type_,
+                                node_data=node.get_base_node_data(),
                                 route_node_state=route_node_state,
                                 parallel_id=parallel_id,
                                 parallel_start_node_id=parallel_start_node_id,
                                 parent_parallel_id=parent_parallel_id,
                                 parent_parallel_start_node_id=parent_parallel_start_node_id,
-                                node_version=node_instance.version(),
+                                node_version=node.version(),
                             )
                             should_continue_retry = False
                         else:
                             yield NodeRunFailedEvent(
                                 error=route_node_state.failed_reason or "Unknown error.",
-                                id=node_instance.id,
-                                node_id=node_instance.node_id,
-                                node_type=node_instance.node_type,
-                                node_data=node_instance.node_data,
+                                id=node.id,
+                                node_id=node.node_id,
+                                node_type=node.type_,
+                                node_data=node.get_base_node_data(),
                                 route_node_state=route_node_state,
                                 parallel_id=parallel_id,
                                 parallel_start_node_id=parallel_start_node_id,
                                 parent_parallel_id=parent_parallel_id,
                                 parent_parallel_start_node_id=parent_parallel_start_node_id,
-                                node_version=node_instance.version(),
+                                node_version=node.version(),
                             )
                             should_continue_retry = False
                     elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                         if (
-                            node_instance.should_continue_on_error
-                            and self.graph.edge_mapping.get(node_instance.node_id)
-                            and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
+                            node.continue_on_error
+                            and self.graph.edge_mapping.get(node.node_id)
+                            and node.error_strategy is ErrorStrategy.FAIL_BRANCH
                         ):
                             run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
                         if run_result.metadata and run_result.metadata.get(
@@ -758,7 +760,7 @@
                             for variable_key, variable_value in run_result.outputs.items():
                                 # append variables to variable pool recursively
                                 self._append_variables_recursively(
-                                    node_id=node_instance.node_id,
+                                    node_id=node.node_id,
                                     variable_key_list=[variable_key],
                                     variable_value=variable_value,
                                 )
@@ -783,26 +785,26 @@
                             run_result.metadata = metadata_dict

                         yield NodeRunSucceededEvent(
-                            id=node_instance.id,
-                            node_id=node_instance.node_id,
-                            node_type=node_instance.node_type,
-                            node_data=node_instance.node_data,
+                            id=node.id,
+                            node_id=node.node_id,
+                            node_type=node.type_,
+                            node_data=node.get_base_node_data(),
                             route_node_state=route_node_state,
                             parallel_id=parallel_id,
                             parallel_start_node_id=parallel_start_node_id,
                             parent_parallel_id=parent_parallel_id,
                             parent_parallel_start_node_id=parent_parallel_start_node_id,
-                            node_version=node_instance.version(),
+                            node_version=node.version(),
                         )
                         should_continue_retry = False

                     break
                 elif isinstance(event, RunStreamChunkEvent):
                     yield NodeRunStreamChunkEvent(
-                        id=node_instance.id,
-                        node_id=node_instance.node_id,
-                        node_type=node_instance.node_type,
-                        node_data=node_instance.node_data,
+                        id=node.id,
+                        node_id=node.node_id,
+                        node_type=node.type_,
+                        node_data=node.get_base_node_data(),
                         chunk_content=event.chunk_content,
                         from_variable_selector=event.from_variable_selector,
                         route_node_state=route_node_state,
@@ -810,14 +812,14 @@
                         parallel_start_node_id=parallel_start_node_id,
                         parent_parallel_id=parent_parallel_id,
                         parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        node_version=node_instance.version(),
+                        node_version=node.version(),
                     )
                 elif isinstance(event, RunRetrieverResourceEvent):
                     yield NodeRunRetrieverResourceEvent(
-                        id=node_instance.id,
-                        node_id=node_instance.node_id,
-                        node_type=node_instance.node_type,
-                        node_data=node_instance.node_data,
+                        id=node.id,
+                        node_id=node.node_id,
+                        node_type=node.type_,
+                        node_data=node.get_base_node_data(),
                         retriever_resources=event.retriever_resources,
                         context=event.context,
                         route_node_state=route_node_state,
@@ -825,7 +827,7 @@
                         parallel_start_node_id=parallel_start_node_id,
                         parent_parallel_id=parent_parallel_id,
                         parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        node_version=node_instance.version(),
+                        node_version=node.version(),
                     )
         except GenerateTaskStoppedError:
             # trigger node run failed event
@@ -833,20 +835,20 @@
             route_node_state.failed_reason = "Workflow stopped."
             yield NodeRunFailedEvent(
                 error="Workflow stopped.",
-                id=node_instance.id,
-                node_id=node_instance.node_id,
-                node_type=node_instance.node_type,
-                node_data=node_instance.node_data,
+                id=node.id,
+                node_id=node.node_id,
+                node_type=node.type_,
+                node_data=node.get_base_node_data(),
                 route_node_state=route_node_state,
                 parallel_id=parallel_id,
                 parallel_start_node_id=parallel_start_node_id,
                 parent_parallel_id=parent_parallel_id,
                 parent_parallel_start_node_id=parent_parallel_start_node_id,
-                node_version=node_instance.version(),
+                node_version=node.version(),
             )
             return
         except Exception as e:
-            logger.exception(f"Node {node_instance.node_data.title} run failed")
+            logger.exception(f"Node {node.title} run failed")
             raise e

     def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
@@ -886,22 +888,14 @@

     def _handle_continue_on_error(
         self,
-        node_instance: BaseNode[BaseNodeData],
+        node: BaseNode,
         error_result: NodeRunResult,
         variable_pool: VariablePool,
        handle_exceptions: list[str] = [],
     ) -> NodeRunResult:
-        """
-        handle continue on error when self._should_continue_on_error is True
-
-
-        :param error_result (NodeRunResult): error run result
-        :param variable_pool (VariablePool): variable pool
-        :return: excption run result
-        """
         # add error message and error type to variable pool
-        variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
-        variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
+        variable_pool.add([node.node_id, "error_message"], error_result.error)
+        variable_pool.add([node.node_id, "error_type"], error_result.error_type)
         # add error message to handle_exceptions
         handle_exceptions.append(error_result.error or "")
         node_error_args: dict[str, Any] = {
@@ -909,21 +903,21 @@
             "error": error_result.error,
             "inputs": error_result.inputs,
             "metadata": {
-                WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
+                WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
             },
         }

-        if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+        if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
             return NodeRunResult(
                 **node_error_args,
                 outputs={
-                    **node_instance.node_data.default_value_dict,
+                    **node.default_value_dict,
                     "error_message": error_result.error,
                     "error_type": error_result.error_type,
                 },
             )
-        elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
-            if self.graph.edge_mapping.get(node_instance.node_id):
+        elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
+            if self.graph.edge_mapping.get(node.node_id):
                 node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
             return NodeRunResult(
                 **node_error_args,
@@ -1,5 +1,4 @@
 import json
-import uuid
 from collections.abc import Generator, Mapping, Sequence
 from typing import Any, Optional, cast

@@ -11,8 +10,10 @@ from sqlalchemy.orm import Session
 from core.agent.entities import AgentToolEntity
 from core.agent.plugin_entities import AgentStrategyParameter
 from core.agent.strategy.plugin import PluginAgentStrategy
+from core.file import File, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.plugin.entities.request import InvokeCredentials
 from core.plugin.impl.exc import PluginDaemonClientSideError
@@ -25,45 +26,75 @@ from core.tools.entities.tool_entities import (
     ToolProviderType,
 )
 from core.tools.tool_manager import ToolManager
-from core.variables.segments import StringSegment
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.variables.segments import ArrayFileSegment, StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import AgentLogEvent
 from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
-from core.workflow.nodes.base.entities import BaseNodeData
-from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.event.event import RunCompletedEvent
-from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
+from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from factories import file_factory
 from factories.agent_factory import get_plugin_agent_strategy
+from models import ToolFile
 from models.model import Conversation
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService

+from .exc import (
+    AgentInputTypeError,
+    AgentInvocationError,
+    AgentMessageTransformError,
+    AgentVariableNotFoundError,
+    AgentVariableTypeError,
+    ToolFileNotFoundError,
+)
+

-class AgentNode(ToolNode):
+class AgentNode(BaseNode):
     """
     Agent Node
     """

-    _node_data_cls = AgentNodeData  # type: ignore
     _node_type = NodeType.AGENT

+    _node_data: AgentNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = AgentNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
+
     def _run(self) -> Generator:
         """
         Run the agent node
         """
-        node_data = cast(AgentNodeData, self.node_data)
-
         try:
             strategy = get_plugin_agent_strategy(
                 tenant_id=self.tenant_id,
-                agent_strategy_provider_name=node_data.agent_strategy_provider_name,
-                agent_strategy_name=node_data.agent_strategy_name,
+                agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
+                agent_strategy_name=self._node_data.agent_strategy_name,
             )
         except Exception as e:
             yield RunCompletedEvent(
@@ -81,13 +112,13 @@ class AgentNode(ToolNode):
         parameters = self._generate_agent_parameters(
             agent_parameters=agent_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
-            node_data=node_data,
+            node_data=self._node_data,
             strategy=strategy,
         )
         parameters_for_log = self._generate_agent_parameters(
             agent_parameters=agent_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
-            node_data=node_data,
+            node_data=self._node_data,
             for_log=True,
             strategy=strategy,
         )
@@ -105,59 +136,39 @@ class AgentNode(ToolNode):
                 credentials=credentials,
             )
         except Exception as e:
+            error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
-                    error=f"Failed to invoke agent: {str(e)}",
+                    error=str(error),
                 )
             )
             return

         try:
             # convert tool messages
-            agent_thoughts: list = []
-
-            thought_log_message = ToolInvokeMessage(
-                type=ToolInvokeMessage.MessageType.LOG,
-                message=ToolInvokeMessage.LogMessage(
-                    id=str(uuid.uuid4()),
-                    label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
-                    parent_id=None,
-                    error=None,
-                    status=ToolInvokeMessage.LogMessage.LogStatus.START,
-                    data={
-                        "strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
-                        "parameters": parameters_for_log,
-                        "thought_process": "Agent strategy execution started",
-                    },
-                    metadata={
-                        "icon": self.agent_strategy_icon,
-                        "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
-                    },
-                ),
-            )
-
-            def enhanced_message_stream():
-                yield thought_log_message
-
-                yield from message_stream
-
             yield from self._transform_message(
-                message_stream,
-                {
+                messages=message_stream,
+                tool_info={
                     "icon": self.agent_strategy_icon,
-                    "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
+                    "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
                 },
-                parameters_for_log,
-                agent_thoughts,
+                parameters_for_log=parameters_for_log,
+                user_id=self.user_id,
+                tenant_id=self.tenant_id,
+                node_type=self.type_,
+                node_id=self.node_id,
+                node_execution_id=self.id,
             )
         except PluginDaemonClientSideError as e:
+            transform_error = AgentMessageTransformError(
+                f"Failed to transform agent message: {str(e)}", original_error=e
+            )
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
-                    error=f"Failed to transform agent message: {str(e)}",
+                    error=str(transform_error),
                 )
             )
@@ -194,7 +205,7 @@
             if agent_input.type == "variable":
                 variable = variable_pool.get(agent_input.value)  # type: ignore
                 if variable is None:
-                    raise ValueError(f"Variable {agent_input.value} does not exist")
+                    raise AgentVariableNotFoundError(str(agent_input.value))
                 parameter_value = variable.value
             elif agent_input.type in {"mixed", "constant"}:
                 # variable_pool.convert_template expects a string template,
@@ -216,7 +227,7 @@
                 except json.JSONDecodeError:
                     parameter_value = parameter_value
             else:
-                raise ValueError(f"Unknown agent input type '{agent_input.type}'")
+                raise AgentInputTypeError(agent_input.type)
             value = parameter_value
             if parameter.type == "array[tools]":
                 value = cast(list[dict[str, Any]], value)
@@ -259,7 +270,7 @@
                     )

                     extra = tool.get("extra", {})
-                    runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
+                    runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
                     tool_runtime = ToolManager.get_agent_tool_runtime(
                         self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
                     )
@@ -343,19 +354,14 @@
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: BaseNodeData,
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
-        """
-        Extract variable selector to variable mapping
-        :param graph_config: graph config
-        :param node_id: node id
-        :param node_data: node data
-        :return:
-        """
-        node_data = cast(AgentNodeData, node_data)
+        # Create typed NodeData from dict
+        typed_node_data = AgentNodeData.model_validate(node_data)
+
         result: dict[str, Any] = {}
-        for parameter_name in node_data.agent_parameters:
-            input = node_data.agent_parameters[parameter_name]
+        for parameter_name in typed_node_data.agent_parameters:
+            input = typed_node_data.agent_parameters[parameter_name]
             if input.type in ["mixed", "constant"]:
                 selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
                 for selector in selectors:
@@ -380,7 +386,7 @@
                 plugin
                 for plugin in plugins
                 if f"{plugin.plugin_id}/{plugin.name}"
-                == cast(AgentNodeData, self.node_data).agent_strategy_provider_name
+                == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
             )
             icon = current_plugin.declaration.icon
         except StopIteration:
@@ -448,3 +454,236 @@
             return tools
         else:
             return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
+
+    def _transform_message(
+        self,
+        messages: Generator[ToolInvokeMessage, None, None],
+        tool_info: Mapping[str, Any],
+        parameters_for_log: dict[str, Any],
+        user_id: str,
+        tenant_id: str,
+        node_type: NodeType,
+        node_id: str,
+        node_execution_id: str,
+    ) -> Generator:
+        """
+        Convert ToolInvokeMessages into tuple[plain_text, files]
+        """
+        # transform message and handle file storage
+        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+            messages=messages,
+            user_id=user_id,
+            tenant_id=tenant_id,
+            conversation_id=None,
+        )
+
+        text = ""
+        files: list[File] = []
+        json: list[dict] = []
+
+        agent_logs: list[AgentLogEvent] = []
+        agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
+        llm_usage: LLMUsage | None = None
+        variables: dict[str, Any] = {}
+
+        for message in message_stream:
+            if message.type in {
+                ToolInvokeMessage.MessageType.IMAGE_LINK,
+                ToolInvokeMessage.MessageType.BINARY_LINK,
+                ToolInvokeMessage.MessageType.IMAGE,
+            }:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+
+                url = message.message.text
+                if message.meta:
+                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
+                else:
+                    transfer_method = FileTransferMethod.TOOL_FILE
+
+                tool_file_id = str(url).split("/")[-1].split(".")[0]
+
+                with Session(db.engine) as session:
+                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+                    tool_file = session.scalar(stmt)
+                    if tool_file is None:
+                        raise ToolFileNotFoundError(tool_file_id)
+
+                mapping = {
+                    "tool_file_id": tool_file_id,
+                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
+                    "transfer_method": transfer_method,
+                    "url": url,
+                }
+                file = file_factory.build_from_mapping(
+                    mapping=mapping,
+                    tenant_id=tenant_id,
+                )
+                files.append(file)
+            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+                # get tool file id
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert message.meta
+
+                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
+                with Session(db.engine) as session:
+                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+                    tool_file = session.scalar(stmt)
+                    if tool_file is None:
+                        raise ToolFileNotFoundError(tool_file_id)
+
+                mapping = {
+                    "tool_file_id": tool_file_id,
+                    "transfer_method": FileTransferMethod.TOOL_FILE,
+                }
+
+                files.append(
+                    file_factory.build_from_mapping(
+                        mapping=mapping,
+                        tenant_id=tenant_id,
+                    )
+                )
+            elif message.type == ToolInvokeMessage.MessageType.TEXT:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                text += message.message.text
+                yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
+            elif message.type == ToolInvokeMessage.MessageType.JSON:
+                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+                if node_type == NodeType.AGENT:
+                    msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
+                    llm_usage = LLMUsage.from_metadata(msg_metadata)
+                    agent_execution_metadata = {
+                        WorkflowNodeExecutionMetadataKey(key): value
+                        for key, value in msg_metadata.items()
+                        if key in WorkflowNodeExecutionMetadataKey.__members__.values()
+                    }
+                if message.message.json_object is not None:
+                    json.append(message.message.json_object)
+            elif message.type == ToolInvokeMessage.MessageType.LINK:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                stream_text = f"Link: {message.message.text}\n"
+                text += stream_text
+                yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
+            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+                variable_name = message.message.variable_name
+                variable_value = message.message.variable_value
+                if message.message.stream:
+                    if not isinstance(variable_value, str):
+                        raise AgentVariableTypeError(
+                            "When 'stream' is True, 'variable_value' must be a string.",
+                            variable_name=variable_name,
+                            expected_type="str",
+                            actual_type=type(variable_value).__name__,
+                        )
+                    if variable_name not in variables:
+                        variables[variable_name] = ""
+                    variables[variable_name] += variable_value
+
+                    yield RunStreamChunkEvent(
+                        chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
+                    )
+                else:
+                    variables[variable_name] = variable_value
+            elif message.type == ToolInvokeMessage.MessageType.FILE:
+                assert message.meta is not None
+                assert isinstance(message.meta, File)
+                files.append(message.meta["file"])
+            elif message.type == ToolInvokeMessage.MessageType.LOG:
+                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+                if message.message.metadata:
+                    icon = tool_info.get("icon", "")
+                    dict_metadata = dict(message.message.metadata)
+                    if dict_metadata.get("provider"):
+                        manager = PluginInstaller()
+                        plugins = manager.list_plugins(tenant_id)
+                        try:
+                            current_plugin = next(
+                                plugin
+                                for plugin in plugins
+                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
+                            )
+                            icon = current_plugin.declaration.icon
+                        except StopIteration:
+                            pass
+                        icon_dark = None
+                        try:
+                            builtin_tool = next(
+                                provider
+                                for provider in BuiltinToolManageService.list_builtin_tools(
+                                    user_id,
+                                    tenant_id,
+                                )
+                                if provider.name == dict_metadata["provider"]
+                            )
+                            icon = builtin_tool.icon
+                            icon_dark = builtin_tool.icon_dark
+                        except StopIteration:
+                            pass
+
+                        dict_metadata["icon"] = icon
+                        dict_metadata["icon_dark"] = icon_dark
+                        message.message.metadata = dict_metadata
+                agent_log = AgentLogEvent(
+                    id=message.message.id,
+                    node_execution_id=node_execution_id,
+                    parent_id=message.message.parent_id,
+                    error=message.message.error,
+                    status=message.message.status.value,
+                    data=message.message.data,
+                    label=message.message.label,
+                    metadata=message.message.metadata,
+                    node_id=node_id,
+                )
+
+                # check if the agent log is already in the list
+                for log in agent_logs:
+                    if log.id == agent_log.id:
+                        # update the log
+                        log.data = agent_log.data
+                        log.status = agent_log.status
+                        log.error = agent_log.error
+                        log.label = agent_log.label
+                        log.metadata = agent_log.metadata
+                        break
+                else:
+                    agent_logs.append(agent_log)
+
+                yield agent_log
+
+        # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
+        json_output: list[dict[str, Any]] = []
+
+        # Step 1: append each agent log as its own dict.
+        if agent_logs:
+            for log in agent_logs:
+                json_output.append(
+                    {
+                        "id": log.id,
+                        "parent_id": log.parent_id,
+                        "error": log.error,
+                        "status": log.status,
+                        "data": log.data,
+                        "label": log.label,
+                        "metadata": log.metadata,
+                        "node_id": log.node_id,
+                    }
+                )
+        # Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
+        if json:
+            json_output.extend(json)
+        else:
+            json_output.append({"data": []})
+
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
+                metadata={
+                    **agent_execution_metadata,
+                    WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
+                    WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
+                },
+                inputs=parameters_for_log,
+                llm_usage=llm_usage,
+            )
+        )
api/core/workflow/nodes/agent/exc.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+from typing import Optional
+
+
+class AgentNodeError(Exception):
+    """Base exception for all agent node errors."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(self.message)
+
+
+class AgentStrategyError(AgentNodeError):
+    """Exception raised when there's an error with the agent strategy."""
+
+    def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
+        self.strategy_name = strategy_name
+        self.provider_name = provider_name
+        super().__init__(message)
+
+
+class AgentStrategyNotFoundError(AgentStrategyError):
+    """Exception raised when the specified agent strategy is not found."""
+
+    def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
+        super().__init__(
+            f"Agent strategy '{strategy_name}' not found"
+            + (f" for provider '{provider_name}'" if provider_name else ""),
+            strategy_name,
+            provider_name,
+        )
+
+
+class AgentInvocationError(AgentNodeError):
+    """Exception raised when there's an error invoking the agent."""
+
+    def __init__(self, message: str, original_error: Optional[Exception] = None):
+        self.original_error = original_error
+        super().__init__(message)
+
+
+class AgentParameterError(AgentNodeError):
+    """Exception raised when there's an error with agent parameters."""
+
+    def __init__(self, message: str, parameter_name: Optional[str] = None):
+        self.parameter_name = parameter_name
+        super().__init__(message)
+
+
+class AgentVariableError(AgentNodeError):
+    """Exception raised when there's an error with variables in the agent node."""
+
+    def __init__(self, message: str, variable_name: Optional[str] = None):
+        self.variable_name = variable_name
+        super().__init__(message)
+
+
+class AgentVariableNotFoundError(AgentVariableError):
+    """Exception raised when a variable is not found in the variable pool."""
+
+    def __init__(self, variable_name: str):
+        super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
+
+
+class AgentInputTypeError(AgentNodeError):
+    """Exception raised when an unknown agent input type is encountered."""
+
+    def __init__(self, input_type: str):
+        super().__init__(f"Unknown agent input type '{input_type}'")
+
+
+class ToolFileError(AgentNodeError):
+    """Exception raised when there's an error with a tool file."""
+
+    def __init__(self, message: str, file_id: Optional[str] = None):
+        self.file_id = file_id
+        super().__init__(message)
+
+
+class ToolFileNotFoundError(ToolFileError):
+    """Exception raised when a tool file is not found."""
+
+    def __init__(self, file_id: str):
+        super().__init__(f"Tool file '{file_id}' does not exist", file_id)
+
+
+class AgentMessageTransformError(AgentNodeError):
+    """Exception raised when there's an error transforming agent messages."""
+
+    def __init__(self, message: str, original_error: Optional[Exception] = None):
+        self.original_error = original_error
+        super().__init__(message)
+
+
+class AgentModelError(AgentNodeError):
+    """Exception raised when there's an error with the model used by the agent."""
+
+    def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
+        self.model_name = model_name
+        self.provider = provider
+        super().__init__(message)
+
+
+class AgentMemoryError(AgentNodeError):
+    """Exception raised when there's an error with the agent's memory."""
+
+    def __init__(self, message: str, conversation_id: Optional[str] = None):
+        self.conversation_id = conversation_id
+        super().__init__(message)
+
+
+class AgentVariableTypeError(AgentNodeError):
+    """Exception raised when a variable has an unexpected type."""
+
+    def __init__(
+        self,
+        message: str,
+        variable_name: Optional[str] = None,
+        expected_type: Optional[str] = None,
+        actual_type: Optional[str] = None,
+    ):
+        self.variable_name = variable_name
+        self.expected_type = expected_type
+        self.actual_type = actual_type
+        super().__init__(message)
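The new module gives agent-node failures a common AgentNodeError base that carries a message, with subclasses adding structured context (variable names, file ids, the original exception). A hedged usage sketch; run_agent_node is a hypothetical caller, not part of the commit:

    from core.workflow.nodes.agent.exc import AgentNodeError, AgentVariableNotFoundError

    try:
        run_agent_node()  # hypothetical caller
    except AgentVariableNotFoundError as e:
        print(e.variable_name)  # structured context on the specific failure
    except AgentNodeError as e:
        print(e.message)  # catch-all for the rest of the family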
@@ -1,5 +1,5 @@
 from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast

 from core.variables import ArrayFileSegment, FileSegment
 from core.workflow.entities.node_entities import NodeRunResult
@@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import (
     VarGenerateRouteChunk,
 )
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.utils.variable_template_parser import VariableTemplateParser


-class AnswerNode(BaseNode[AnswerNodeData]):
-    _node_data_cls = AnswerNodeData
+class AnswerNode(BaseNode):
     _node_type = NodeType.ANSWER

+    _node_data: AnswerNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = AnswerNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
     @classmethod
     def version(cls) -> str:
         return "1"
@@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]):
         :return:
         """
         # generate routes
-        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
+        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)

         answer = ""
         files = []
@@ -60,16 +83,12 @@
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: AnswerNodeData,
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
-        """
-        Extract variable selector to variable mapping
-        :param graph_config: graph config
-        :param node_id: node id
-        :param node_data: node data
-        :return:
-        """
-        variable_template_parser = VariableTemplateParser(template=node_data.answer)
+        # Create typed NodeData from dict
+        typed_node_data = AnswerNodeData.model_validate(node_data)
+
+        variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
         variable_selectors = variable_template_parser.extract_variable_selectors()

         variable_mapping = {}
@@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: Optional[str] = None
+    version: str = "1"
     error_strategy: Optional[ErrorStrategy] = None
     default_value: Optional[list[DefaultValue]] = None
-    version: str = "1"
     retry_config: RetryConfig = RetryConfig()

     @property
-    def default_value_dict(self):
+    def default_value_dict(self) -> dict[str, Any]:
         if self.default_value:
             return {item.key: item.value for item in self.default_value}
         return {}
@@ -1,28 +1,22 @@
 import logging
 from abc import abstractmethod
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent

-from .entities import BaseNodeData
-
 if TYPE_CHECKING:
-    from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
     from core.workflow.graph_engine.entities.event import InNodeEvent
+    from core.workflow.graph_engine.entities.graph import Graph
+    from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+    from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

 logger = logging.getLogger(__name__)

-GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
-

-class BaseNode(Generic[GenericNodeData]):
-    _node_data_cls: type[GenericNodeData]
+class BaseNode:
     _node_type: ClassVar[NodeType]

     def __init__(
@@ -56,8 +50,8 @@
         self.node_id = node_id

-        node_data = self._node_data_cls.model_validate(config.get("data", {}))
-        self.node_data = node_data
+    @abstractmethod
+    def init_node_data(self, data: Mapping[str, Any]) -> None: ...

     @abstractmethod
     def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@@ -130,9 +124,9 @@
         if not node_id:
             raise ValueError("Node ID is required when extracting variable selector to variable mapping.")

-        node_data = cls._node_data_cls(**config.get("data", {}))
+        # Pass raw dict data instead of creating NodeData instance
         data = cls._extract_variable_selector_to_variable_mapping(
-            graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
+            graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
         )
         return data
@@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: GenericNodeData,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def node_type(self) -> NodeType:
|
||||
"""
|
||||
Get node type
|
||||
:return:
|
||||
"""
|
||||
def type_(self) -> NodeType:
|
||||
return self._node_type
|
||||
|
||||
@classmethod
|
||||
@@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@property
|
||||
def should_continue_on_error(self) -> bool:
|
||||
"""judge if should continue on error
|
||||
|
||||
Returns:
|
||||
bool: if should continue on error
|
||||
"""
|
||||
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
||||
def continue_on_error(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def should_retry(self) -> bool:
|
||||
"""judge if should retry
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
Returns:
|
||||
bool: if should retry
|
||||
"""
|
||||
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
|
||||
# Abstract methods that subclasses must implement to provide access
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> Optional[str]:
|
||||
"""Get the node description."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
...
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
@property
|
||||
def retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
return self._get_retry_config()
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
"""Get the node title."""
|
||||
return self._get_title()
|
||||
|
||||
@property
|
||||
def description(self) -> Optional[str]:
|
||||
"""Get the node description."""
|
||||
return self._get_description()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._get_default_value_dict()
|
||||
|
||||
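Taken together, the abstract accessors above define the contract every concrete node now fulfills instead of the old Generic[GenericNodeData] parameterization. A condensed sketch of the pattern, assuming the names from the hunk above (BaseNode, BaseNodeData, RetryConfig, ErrorStrategy, NodeType) are in scope; MyNodeData and MyNode are illustrative, not part of the commit, and _run/version are omitted:

from collections.abc import Mapping
from typing import Any, Optional

class MyNodeData(BaseNodeData):
    # One node-specific field on top of title/desc/error_strategy/etc.
    query: str = ""

class MyNode(BaseNode):
    _node_type = NodeType.CODE  # any NodeType member works for the sketch
    _node_data: MyNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # Validation moves here; BaseNode.__init__ no longer builds NodeData.
        self._node_data = MyNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data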
@@ -11,8 +11,9 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType

from .exc import (
CodeNodeError,
@@ -21,10 +22,32 @@ from .exc import (
)


class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
class CodeNode(BaseNode):
_node_type = NodeType.CODE

_node_data: CodeNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = CodeNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@@ -47,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]):

def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language
code = self.node_data.code
code_language = self._node_data.code_language
code = self._node_data.code

# Get variables
variables = {}
for variable_selector in self.node_data.variables:
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@@ -68,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)

# Transform result
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@@ -334,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)

return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in typed_node_data.variables
}

@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None

@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

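Because the extractor now accepts the raw data mapping, callers no longer construct a typed instance first. A hedged usage sketch, assuming the public classmethod wrapper keeps its graph_config/config keyword signature; the node config below is made up and only its shape matters:

node_config = {
    "id": "code_1",
    "data": {
        "title": "Code",
        "code_language": "python3",
        "code": "def main(x):\n    return {'result': x}",
        "variables": [{"variable": "x", "value_selector": ["start_1", "x"]}],
        "outputs": {},
    },
}

# Validation of "data" happens inside the node class via model_validate.
mapping = CodeNode.extract_variable_selector_to_variable_mapping(
    graph_config={"edges": [], "nodes": []}, config=node_config
)
# Expected shape: {"code_1.x": ["start_1", "x"]}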
@@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast

import chardet
import docx
@@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType

from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
@@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)


class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
class DocumentExtractorNode(BaseNode):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
"""

_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR

_node_data: DocumentExtractorNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DocumentExtractorNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"

def _run(self):
variable_selector = self.node_data.variable_selector
variable_selector = self._node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)

if variable is None:
@@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
# Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)

return {node_id + ".files": typed_node_data.variable_selector}


def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:

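The module-level helper whose signature appears above dispatches on MIME type. A rough standalone sketch of that dispatch; the branches shown are a small subset and the decoding policy is an assumption (the real node sniffs encodings with chardet, per the imports in this hunk):

def extract_text_sketch(*, file_content: bytes, mime_type: str) -> str:
    # Plain text: decode with a lenient fallback.
    if mime_type.startswith("text/"):
        return file_content.decode("utf-8", errors="replace")
    # Binary formats (PDF, DOCX, ...) are routed to dedicated parsers,
    # e.g. python-docx for DOCX per the `import docx` above.
    raise ValueError(f"unsupported MIME type for this sketch: {mime_type}")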
@@ -1,14 +1,40 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType


class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
class EndNode(BaseNode):
_node_type = NodeType.END

_node_data: EndNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = EndNodeData(**data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"
@@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]):
Run node
:return:
"""
output_variables = self.node_data.outputs
output_variables = self._node_data.outputs

outputs = {}
for variable_selector in output_variables:

@@ -35,7 +35,3 @@ class ErrorStrategy(StrEnum):
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"


CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE

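With the module-level CONTINUE_ON_ERROR_NODE_TYPE / RETRY_ON_ERROR_NODE_TYPE lists deleted above, the capability now lives on each node class: BaseNode defaults both properties to False and supporting nodes override them. A sketch of how a node opts in (ToolLikeNode is illustrative; the bodies mirror the overrides added elsewhere in this diff):

class ToolLikeNode(BaseNode):
    @property
    def continue_on_error(self) -> bool:
        # Enabled only when the workflow author configured an error strategy.
        return self._node_data.error_strategy is not None

    @property
    def retry(self) -> bool:
        return self._node_data.retry_config.retry_enabled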
@@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
@@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)


class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_data_cls = HttpRequestNodeData
class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST

_node_data: HttpRequestNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = HttpRequestNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
@@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
process_data = {}
try:
http_executor = Executor(
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
node_data=self._node_data,
timeout=self._get_request_timeout(self._node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
@@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):

response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
if not response.response.is_success and (self.continue_on_error or self.retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData.model_validate(node_data)

selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if typed_node_data.body:
body_type = typed_node_data.body.type
data = typed_node_data.body.data
match body_type:
case "binary":
if len(data) != 1:
@@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
files.append(file)

return ArrayFileSegment(value=files)

@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None

@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

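The selector extraction above walks Dify's {{#node_id.variable#}} template markers in the URL, headers, and params. A rough illustration of what a single call might yield; the exact VariableSelector shape is an assumption, not confirmed by this diff:

url = "https://api.example.com/users/{{#start_1.user_id#}}"
selectors = variable_template_parser.extract_selectors_from_template(url)
# Roughly: [VariableSelector(variable="#start_1.user_id#", value_selector=["start_1", "user_id"])]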
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Any, Literal, Optional

from typing_extensions import deprecated

@@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor


class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE

_node_data: IfElseNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IfElseNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"
@@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if self.node_data.cases:
for case in self.node_data.cases:
if self._node_data.cases:
for case in self._node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
@@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
conditions=self._node_data.conditions or [],
operator=self._node_data.logical_operator or "and",
)

selected_case_id = "true" if final_result else "false"
@@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData.model_validate(node_data)

var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []:
for case in typed_node_data.cases or []:
for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector

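The mapping key built above embeds the dot-joined selector between # markers. A tiny self-contained check of that format:

node_id = "if_else_1"
variable_selector = ["start_1", "age"]
key = "{}.#{}#".format(node_id, ".".join(variable_selector))
assert key == "if_else_1.#start_1.age#"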
@@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
@@ -56,14 +57,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class IterationNode(BaseNode[IterationNodeData]):
class IterationNode(BaseNode):
"""
Iteration Node.
"""

_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION

_node_data: IterationNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IterationNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
@@ -83,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
Run the node.
"""
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)

if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")

if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@@ -116,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]):

graph_config = self.graph_config

if not self.node_data.start_node_id:
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")

root_node_id = self.node_data.start_node_id
root_node_id = self._node_data.start_node_id

# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
@@ -161,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
@@ -172,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
@@ -181,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]):
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
@@ -242,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]

# Flatten the list of lists
@@ -253,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -278,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -305,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = IterationNodeData.model_validate(node_data)

variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector,
f"{node_id}.input_selector": typed_node_data.iterator_selector,
}

# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)

if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
@@ -375,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id

iter_metadata = {
@@ -438,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]):
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self.node_data.is_parallel:
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@@ -456,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -478,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]):
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@@ -491,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@@ -512,15 +531,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@@ -531,12 +550,12 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.remove([node_id])

# iteration run failed
if self.node_data.is_parallel:
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@@ -549,8 +568,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -569,7 +588,7 @@ class IterationNode(BaseNode[IterationNodeData]):
return
yield metadata_event

current_output_segment = variable_pool.get(self.node_data.output_selector)
current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
@@ -588,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
@@ -601,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},

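The three ErrorHandleMode branches above differ mainly in what happens to a failed iteration's slot in outputs. A condensed sketch of the policy, not the repo's event-driven control flow:

def on_iteration_failure(mode: "ErrorHandleMode", outputs: list, index: int) -> None:
    # CONTINUE_ON_ERROR keeps a None placeholder and moves to the next item;
    # REMOVE_ABNORMAL_OUTPUT leaves the placeholder, which is filtered after
    # the loop ([o for o in outputs if o is not None]); TERMINATED stops the run.
    if mode == ErrorHandleMode.CONTINUE_ON_ERROR:
        outputs[index] = None
    elif mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
        pass
    elif mode == ErrorHandleMode.TERMINATED:
        raise RuntimeError("iteration terminated on first failure")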
@@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData


class IterationStartNode(BaseNode[IterationStartNodeData]):
class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""

_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START

_node_data: IterationStartNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IterationStartNodeData(**data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"

@@ -1,10 +1,10 @@
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing import Literal, Optional

from pydantic import BaseModel, Field

from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig


class RerankingModelConfig(BaseModel):
@@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel):
weights: Optional[WeightedScoreConfig] = None


class ModelConfig(BaseModel):
"""
Model Config.
"""

provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}


class SingleRetrievalConfig(BaseModel):
"""
Single Retrieval Config.
@@ -129,7 +118,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_model_config: ModelConfig
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)


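Note that metadata_model_config flips from Optional[ModelConfig] = None to required, so a missing value now fails when the node data is loaded instead of via the run-time ValueError removed later in this diff. A self-contained sketch of the pydantic behavior (ModelConfigSketch and KRDataSketch are stand-ins, not the repo's classes):

from pydantic import BaseModel, ValidationError

class ModelConfigSketch(BaseModel):
    provider: str
    name: str
    mode: str

class KRDataSketch(BaseModel):
    metadata_model_config: ModelConfigSketch  # required: no Optional, no default

try:
    KRDataSketch.model_validate({})
except ValidationError:
    print("missing metadata_model_config is rejected at validation time")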
@@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
@@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
)
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
@@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService

from .entities import KnowledgeRetrievalNodeData, ModelConfig
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
@@ -56,6 +68,10 @@ from .exc import (
ModelQuotaExceededError,
)

if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState

logger = logging.getLogger(__name__)

default_retrieval_model = {
@@ -67,18 +83,76 @@ default_retrieval_model = {
}


class KnowledgeRetrievalNode(LLMNode):
_node_data_cls = KnowledgeRetrievalNodeData  # type: ignore
class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL

_node_data: KnowledgeRetrievalNodeData

# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]

_llm_file_saver: LLMFileSaver

def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []

if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls):
return "1"

def _run(self) -> NodeRunResult:  # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode):

# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -435,20 +509,15 @@ class KnowledgeRetrievalNode(LLMNode):
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
metadata_model_config = node_data.metadata_model_config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self.get_model_config(metadata_model_config)
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
@@ -458,16 +527,23 @@ class KnowledgeRetrievalNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)

result_text = ""
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.metadata_model_config,  # type: ignore
generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)

for event in generator:
@@ -557,17 +633,13 @@ class KnowledgeRetrievalNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData,  # type: ignore
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)

variable_mapping = {}
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping

def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
@@ -629,7 +701,7 @@ class KnowledgeRetrievalNode(LLMNode):
)

def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)  # type: ignore
model_mode = ModelMode(node_data.metadata_model_config.mode)
input_text = query

prompt_messages: list[LLMNodeChatModelMessage] = []

@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Any, Literal, Union
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Optional, Union

from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType

from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError


class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR

_node_data: ListOperatorNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ListOperatorNodeData(**data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"
@@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}

variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
if variable is None:
error_message = f"Variable not found for selector: {self.node_data.variable}"
error_message = f"Variable not found for selector: {self._node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
return NodeRunResult(
@@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):

try:
# Filter
if self.node_data.filter_by.enabled:
if self._node_data.filter_by.enabled:
variable = self._apply_filter(variable)

# Extract
if self.node_data.extract_by.enabled:
if self._node_data.extract_by.enabled:
variable = self._extract_slice(variable)

# Order
if self.node_data.order_by.enabled:
if self._node_data.order_by.enabled:
variable = self._apply_order(variable)

# Slice
if self.node_data.limit.enabled:
if self._node_data.limit.enabled:
variable = self._apply_slice(variable)

outputs = {
@@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self.node_data.filter_by.conditions:
for condition in self._node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@@ -137,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
@@ -152,13 +175,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size]
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})

def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
value -= 1

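_extract_slice renders the serial template, then converts the user-facing 1-based index to a 0-based list index. The arithmetic in isolation:

def to_zero_based(serial_text: str) -> int:
    value = int(serial_text)
    if value < 1:
        raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
    return value - 1  # "1" -> 0, "3" -> 2

assert to_zero_based("1") == 0
assert to_zero_based("3") == 2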
@@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Optional

from pydantic import BaseModel, Field, field_validator
@@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")


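Widening structured_output from dict | None to Mapping[str, Any] | None signals the field is treated as read-only, while pydantic keeps accepting plain dicts. A small self-contained sketch (Sketch is a stand-in model):

from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel

class Sketch(BaseModel):
    structured_output: Mapping[str, Any] | None = None

s = Sketch(structured_output={"type": "object", "properties": {}})
print(s.structured_output)  # validated as a mapping; plain dicts still work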
@@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
NodeEvent,
|
||||
@@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
class LLMNode(BaseNode):
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
_node_data: LLMNodeData
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
@@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = LLMNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
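Editor's note: the hunks above replace the generic `BaseNode[LLMNodeData]` pattern with a plain `BaseNode` plus an explicit `init_node_data` step. A minimal sketch of the resulting two-phase lifecycle; the construction arguments and the `config["data"]` convention are assumed, not taken from this commit:

    node = LLMNode(
        id="node-1",
        config=node_config,                   # raw mapping from the workflow graph definition
        graph_init_params=graph_init_params,  # engine-provided, assumed here
        graph=graph,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(node_config["data"])  # validates the dict into a typed LLMNodeData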
@@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]):

try:
# init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)

# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data)
inputs = self._fetch_inputs(node_data=self._node_data)

# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)

# merge inputs
inputs.update(jinja_inputs)
@@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]):
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=self.node_data.vision.configs.variable_selector,
selector=self._node_data.vision.configs.variable_selector,
)
if self.node_data.vision.enabled
if self._node_data.vision.enabled
else []
)

@@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#files#"] = [file.to_dict() for file in files]

# fetch context value
generator = self._fetch_context(node_data=self.node_data)
generator = self._fetch_context(node_data=self._node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
@@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#context#"] = context

# fetch model config
model_instance, model_config = self._fetch_model_config(self.node_data.model)
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self._node_data.model,
tenant_id=self.tenant_id,
)

# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self.node_data.memory,
node_data_memory=self._node_data.memory,
model_instance=model_instance,
)

query = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if self._node_data.memory:
query = self._node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
query = query_variable.text

prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
prompt_template=self._node_data.prompt_template,
memory_config=self._node_data.memory,
vision_enabled=self._node_data.vision.enabled,
vision_detail=self._node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
)

# handle invoke result
generator = self._invoke_llm(
node_data_model=self.node_data.model,
generator = LLMNode.invoke_llm(
node_data_model=self._node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)

structured_output: LLMStructuredOutput | None = None
@@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)

def _invoke_llm(
self,
@staticmethod
def invoke_llm(
*,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Optional[Mapping[str, Any]] = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")

if self.node_data.structured_output_enabled:
output_schema = self._fetch_structured_output_schema()
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
invoke_result = invoke_llm_with_structured_output(
provider=model_instance.provider,
model_schema=model_schema,
@@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=self.user_id,
user=user_id,
)
else:
invoke_result = model_instance.invoke_llm(
@@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=self.user_id,
user=user_id,
)

return self._handle_invoke_result(invoke_result=invoke_result)
return LLMNode.handle_invoke_result(
invoke_result=invoke_result,
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
)
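Editor's note: `invoke_llm` is now a keyword-only @staticmethod, so other nodes can reuse it without subclassing LLMNode (QuestionClassifierNode further below does exactly this). A hedged sketch of a call site; every argument is assumed to be prepared by the caller:

    generator = LLMNode.invoke_llm(
        node_data_model=model_config_entity,  # the node's ModelConfig
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        stop=stop,
        user_id=user_id,
        structured_output_enabled=False,
        structured_output=None,
        file_saver=file_saver,                # an LLMFileSaver implementation
        file_outputs=[],                      # collects multimodal File outputs
        node_id="my-node",
    )
    for event in generator:
        ...                                   # stream of NodeEvent / LLMStructuredOutput items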

def _handle_invoke_result(
self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
@staticmethod
def handle_invoke_result(
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
event = self._handle_blocking_result(invoke_result=invoke_result)
event = LLMNode.handle_blocking_result(
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
)
yield event
return

@@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]):
yield result
if isinstance(result, LLMResultChunk):
contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=contents,
file_saver=file_saver,
file_outputs=file_outputs,
):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(
chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])

# Update the whole metadata
if not model and result.model:
@@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]):

yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)

def _image_file_to_markdown(self, file: "File", /):
@staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk

@@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]):

return None

@staticmethod
def _fetch_model_config(
self, node_data_model: ModelConfig
*,
node_data_model: ModelConfig,
tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
tenant_id=tenant_id, node_data_model=node_data_model
)
completion_params = model_config_with_cred.parameters

@@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]):
node_data_model.completion_params = completion_params
return model, model_config_with_cred

def _fetch_prompt_messages(
self,
@staticmethod
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
@@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]):
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages: list[PromptMessage] = []

if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
self._handle_list_messages(
LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@@ -602,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]):
edition_type="basic",
)
prompt_messages.extend(
self._handle_list_messages(
LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)

model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model,
@@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
prompt_template = node_data.prompt_template
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)

prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
@@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector

memory = node_data.memory
memory = typed_node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector

if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if typed_node_data.context.enabled:
variable_mapping["#context#"] = typed_node_data.context.variable_selector

if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if typed_node_data.vision.enabled:
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector

if node_data.memory:
if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]

if node_data.prompt_config:
if typed_node_data.prompt_config:
enable_jinja = False

if isinstance(prompt_template, list):
@@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]):
enable_jinja = True

if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector

variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}

def _handle_list_messages(
self,
@staticmethod
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
@@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]):

return prompt_messages

def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
@staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=invoke_result.message.content,
file_saver=saver,
file_outputs=file_outputs,
):
buffer.write(text_part)

return ModelInvokeCompletedEvent(
@@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]):
finish_reason=None,
)

def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
@staticmethod
def save_multimodal_image_output(
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.

There are two kinds of multimodal outputs:
@@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]):

Currently, only image files are supported.
"""
# Inject the saver somehow...
_saver = self._llm_file_saver

# If this
if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = _saver.save_binary_string(
saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
self._file_outputs.append(saved_file)
return saved_file

def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
model_name = self.node_data.model.name
model_name = self._node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]):
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema

def _fetch_structured_output_schema(self) -> dict[str, Any]:
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.

Returns:
dict[str, Any]: The structured output schema
"""
if not self.node_data.structured_output:
if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")

@@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")

@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
self,
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.

@@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item)
self._file_outputs.append(file)
yield self._image_file_to_markdown(file)
file = LLMNode.save_multimodal_image_output(
content=item,
file_saver=file_saver,
)
file_outputs.append(file)
yield LLMNode._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
@@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]):
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)

@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None

@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled


def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole

@@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData


class LoopEndNode(BaseNode[LoopEndNodeData]):
class LoopEndNode(BaseNode):
"""
Loop End Node.
"""

_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END

_node_data: LoopEndNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopEndNodeData(**data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"

@@ -3,7 +3,7 @@ import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

from configs import dify_config
from core.variables import (
@@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
@@ -43,14 +44,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class LoopNode(BaseNode[LoopNodeData]):
class LoopNode(BaseNode):
"""
Loop Node.
"""

_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP

_node_data: LoopNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"
@@ -58,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]):
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
loop_count = self.node_data.loop_count
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
loop_count = self._node_data.loop_count
break_conditions = self._node_data.break_conditions
logical_operator = self._node_data.logical_operator

inputs = {"loop_count": loop_count}

if not self.node_data.start_node_id:
if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")

# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")

@@ -78,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):

# Initialize loop variables
loop_variable_selectors = {}
if self.node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
if self._node_data.loop_variables:
for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
@@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
@@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
inputs=inputs,
)
)
@@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
@@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
_outputs[loop_variable_key] = None

_outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs
self._node_data.outputs = _outputs

if check_break_result:
return {"check_break_result": True}
@@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
index=next_index,
pre_loop_output=self.node_data.outputs,
pre_loop_output=self._node_data.outputs,
)

return {"check_break_result": False}
@@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LoopNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = LoopNodeData.model_validate(node_data)

variable_mapping = {}

# init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)

if not loop_graph:
raise ValueError("loop graph not found")
@@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]):

variable_mapping.update(sub_node_variable_mapping)

for loop_variable in node_data.loop_variables or []:
for loop_variable in typed_node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
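Editor's note: across these hunks `_extract_variable_selector_to_variable_mapping` now accepts the raw `node_data` mapping and validates it internally, so callers no longer need the typed class. A sketch of the new contract; the graph config and ids are invented for illustration:

    mapping = LoopNode._extract_variable_selector_to_variable_mapping(
        graph_config=graph_config,     # raw workflow graph mapping
        node_id="loop-1",
        node_data=graph_node["data"],  # plain dict; validated via model_validate inside
    )
    # keys come back prefixed with the node id, e.g. {"loop-1.x": ["start", "x"], ...}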

@@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData


class LoopStartNode(BaseNode[LoopStartNodeData]):
class LoopStartNode(BaseNode):
"""
Loop Start Node.
"""

_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START

_node_data: LoopStartNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopStartNodeData(**data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"

@@ -29,8 +29,9 @@ from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
@@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode):
Parameter Extractor Node.
"""

# FIXME: figure out why here is different from super class
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR

_node_data: ParameterExtractorNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ParameterExtractorNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None

@@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
node_data = cast(ParameterExtractorNodeData, self._node_data)
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

@@ -398,7 +420,7 @@ class ParameterExtractorNode(BaseNode):
"""
Generate prompt engineering prompt.
"""
model_mode = ModelMode.value_of(data.model.mode)
model_mode = ModelMode(data.model.mode)

if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
@@ -694,7 +716,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -721,7 +743,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
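Editor's note: the `ModelMode.value_of(...)` calls above are replaced by plain enum construction, which is equivalent for a value-backed Enum. A self-contained sketch of the idea; the real ModelMode lives in core.prompt.simple_prompt_transform, and the members shown here are assumed:

    from enum import Enum

    class ModelMode(Enum):  # assumed shape, for illustration only
        COMPLETION = "completion"
        CHAT = "chat"

    assert ModelMode("chat") is ModelMode.CHAT
    # ModelMode("bogus") raises ValueError, preserving value_of()'s failure mode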
@@ -827,19 +849,15 @@ class ParameterExtractorNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData, # type: ignore
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)

if node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}

if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector



@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
LLMNode,
@@ -20,6 +23,7 @@ from core.workflow.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown

@@ -35,17 +39,77 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)

if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState

class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData # type: ignore

class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER

_node_data: QuestionClassifierNodeData

_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver

def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []

if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = QuestionClassifierNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls):
return "1"

def _run(self):
node_data = cast(QuestionClassifierNodeData, self.node_data)
node_data = cast(QuestionClassifierNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool

# extract variables
@@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode):
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)

result_text = ""
@@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode):

try:
# handle invoke result
generator = self._invoke_llm(
generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)

for event in generator:
@@ -183,23 +257,18 @@ class QuestionClassifierNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Any,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)

variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_selectors: list[VariableSelector] = []
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)

variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

@@ -265,7 +334,7 @@ class QuestionClassifierNode(LLMNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:

@@ -1,15 +1,41 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.start.entities import StartNodeData


class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
class StartNode(BaseNode):
_node_type = NodeType.START

_node_data: StartNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = StartNodeData(**data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"

@@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))


class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
_node_data_cls = TemplateTransformNodeData
class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM

_node_data: TemplateTransformNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TemplateTransformNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
def _run(self) -> NodeRunResult:
# Get variables
variables = {}
for variable_selector in self.node_data.variables:
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
@@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData.model_validate(node_data)

return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in typed_node_data.variables
}
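Editor's note: a small runnable illustration of what this comprehension produces; the node id and selector are invented:

    node_id = "tpl-1"
    variables = [("name", ["start", "name"])]  # (variable, value_selector) pairs, hypothetical
    mapping = {node_id + "." + var: selector for var, selector in variables}
    assert mapping == {"tpl-1.name": ["start", "name"]}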

@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@@ -19,10 +18,10 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
@@ -37,14 +36,18 @@ from .exc import (
)


class ToolNode(BaseNode[ToolNodeData]):
class ToolNode(BaseNode):
"""
Tool Node
"""

_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL

_node_data: ToolNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ToolNodeData.model_validate(data)

@classmethod
def version(cls) -> str:
return "1"
@@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]):
Run the tool node
"""

node_data = cast(ToolNodeData, self.node_data)
node_data = cast(ToolNodeData, self._node_data)

# fetch tool icon
tool_info = {
@@ -67,9 +70,9 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
from core.tools.tool_manager import ToolManager

variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
@@ -88,12 +91,12 @@ class ToolNode(BaseNode[ToolNodeData]):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
for_log=True,
)
# get conversation id
@@ -124,7 +127,14 @@ class ToolNode(BaseNode[ToolNodeData]):

try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self.node_id,
)
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -191,7 +201,9 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None,
user_id: str,
tenant_id: str,
node_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -199,8 +211,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)

@@ -208,9 +220,6 @@ class ToolNode(BaseNode[ToolNodeData]):
files: list[File] = []
json: list[dict] = []

agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}

for message in message_stream:
@@ -243,7 +252,7 @@ class ToolNode(BaseNode[ToolNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
@@ -266,45 +275,36 @@ class ToolNode(BaseNode[ToolNodeData]):
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
# JSON message handling for tool node
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value

yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
@@ -319,7 +319,7 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
@@ -334,8 +334,8 @@ class ToolNode(BaseNode[ToolNodeData]):
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id,
self.tenant_id,
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
@@ -347,57 +347,10 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=self.node_id,
)

# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)

yield agent_log
elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage)
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=message.message.retriever_resources,
|
||||
context=message.message.context,
|
||||
)
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json:
|
||||
json_output.extend(json)
|
||||
@@ -409,12 +362,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -424,7 +374,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ToolNodeData,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -433,9 +383,12 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||
|
||||
result = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
@@ -449,3 +402,29 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@property
|
||||
def continue_on_error(self) -> bool:
|
||||
return self._node_data.error_strategy is not None
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
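Editor's note: the ToolNode hunks above all apply one refactor, `_transform_message` stops reading `self.user_id`, `self.tenant_id`, and `self.node_id` and receives them as explicit parameters, so the transformation logic no longer depends on a bound node instance. A minimal, self-contained sketch of that pattern, with invented names rather than dify's actual classes:

from collections.abc import Iterator


class Node:
    def __init__(self, node_id: str, tenant_id: str) -> None:
        self.node_id = node_id
        self.tenant_id = tenant_id

    # Before: the helper reads instance state directly, so it can
    # only run against a fully constructed node.
    def transform_bound(self, messages: list[str]) -> Iterator[str]:
        for m in messages:
            yield f"{self.tenant_id}/{self.node_id}: {m}"

    # After: context is passed explicitly; the helper is testable
    # and reusable without a node instance.
    @staticmethod
    def transform(messages: list[str], *, node_id: str, tenant_id: str) -> Iterator[str]:
        for m in messages:
            yield f"{tenant_id}/{node_id}: {m}"


assert list(Node.transform(["hello"], node_id="n1", tenant_id="t1")) == ["t1/n1: hello"]

Passing context explicitly is consistent with the commit's stated goal of decoupling Node from NodeData.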
@@ -1,17 +1,41 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData


class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR

_node_data: VariableAssignerNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerNodeData(**data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

@classmethod
def version(cls) -> str:
return "1"
@@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}

if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
for selector in self._node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
@@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
for group in self.node_data.advanced_settings.groups:
for group in self._node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)


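Editor's note: this file shows the commit's core pattern. `BaseNode` is no longer generic over a data class; each concrete node validates a raw mapping in `init_node_data` and exposes the typed result through small delegating getters. A rough, self-contained sketch of the shape (pydantic v2 assumed, names illustrative):

from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel


class BaseNodeData(BaseModel):
    title: str
    desc: Optional[str] = None


class MyNodeData(BaseNodeData):
    variables: list[str] = []


class MyNode:
    _node_data: MyNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # Validate the raw dict into a typed model once, up front.
        self._node_data = MyNodeData.model_validate(data)

    def _get_title(self) -> str:
        return self._node_data.title


node = MyNode()
node.init_node_data({"title": "demo", "variables": ["x"]})
assert node._get_title() == "demo"

The getters look repetitive per node, but they let the base class stay untyped while each subclass keeps a fully typed private `_node_data`.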
@@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
@@ -22,11 +23,33 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY

_node_data: VariableAssignerData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

def __init__(
self,
id: str,
@@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
mapping = {}
assigned_variable_node_id = node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.assigned_variable_selector
# Create typed NodeData from dict
typed_node_data = VariableAssignerData.model_validate(node_data)

selector_key = ".".join(node_data.input_variable_selector)
mapping = {}
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(typed_node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.assigned_variable_selector

selector_key = ".".join(typed_node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.input_variable_selector
mapping[key] = typed_node_data.input_variable_selector
return mapping

def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
assigned_variable_selector = self._node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")

match self.node_data.write_mode:
match self._node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})

case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
@@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})

case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")

# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

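Editor's note: the assigner's `_run` dispatches on the write mode with structural pattern matching; the only change in this hunk is reading the mode from the private `_node_data`. For reference, a minimal match/case in the same shape (enum values invented, not dify's `WriteMode`):

from enum import Enum


class WriteMode(Enum):
    OVER_WRITE = "over-write"
    APPEND = "append"


def apply(mode: WriteMode, current: list[int], incoming: int) -> list[int]:
    match mode:
        case WriteMode.OVER_WRITE:
            return [incoming]
        case WriteMode.APPEND:
            return current + [incoming]
        case _:
            raise ValueError(f"unsupported write mode: {mode}")


assert apply(WriteMode.APPEND, [1], 2) == [1, 2]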
@@ -1,6 +1,6 @@
import json
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, TypeAlias, cast
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
@@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@@ -28,8 +29,6 @@ from .exc import (
VariableNotFoundError,
)

_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector


class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER

_node_data: VariableAssignerNodeData

def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerNodeData.model_validate(data)

def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config

def _get_title(self) -> str:
return self._node_data.title

def _get_description(self) -> Optional[str]:
return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
return self._node_data

def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()

@@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData.model_validate(node_data)

var_mapping: dict[str, Sequence[str]] = {}
for item in node_data.items:
for item in typed_node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping

def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
inputs = self._node_data.model_dump()
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = []

try:
for item in self.node_data.items:
for item in self._node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)

# ==================== Validation Part

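Editor's note: across these files, `_extract_variable_selector_to_variable_mapping` now accepts a plain `Mapping[str, Any]` and builds the typed model itself, so callers no longer have to construct node data before asking for the mapping. A hedged sketch of that boundary change (illustrative model, pydantic v2 assumed):

from collections.abc import Mapping, Sequence
from typing import Any

from pydantic import BaseModel


class AssignerData(BaseModel):
    # Each item is a variable selector, e.g. ["conversation", "count"].
    items: list[list[str]]


def extract_mapping(node_id: str, node_data: Mapping[str, Any]) -> Mapping[str, Sequence[str]]:
    # The classmethod now owns validation of the raw dict.
    typed = AssignerData.model_validate(node_data)
    return {f"{node_id}.{'.'.join(sel)}": sel for sel in typed.items}


assert extract_mapping("n1", {"items": [["conversation", "count"]]}) == {
    "n1.conversation.count": ["conversation", "count"]
}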
@@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
@@ -146,7 +146,7 @@ class WorkflowEntry:
graph = Graph.init(graph_config=workflow.graph_dict)

# init workflow run state
node_instance = node_cls(
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -190,17 +190,11 @@ class WorkflowEntry:

try:
# run node
generator = node_instance.run()
generator = node.run()
except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
return node, generator

@classmethod
def run_free_node(
@@ -262,7 +256,7 @@ class WorkflowEntry:

node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node_instance: BaseNode = node_cls(
node: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -297,17 +291,12 @@ class WorkflowEntry:
)

# run node
generator = node_instance.run()
generator = node.run()

return node_instance, generator
return node, generator
except Exception as e:
logger.exception(
"error while running node_instance, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))

@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:

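Editor's note: the rewritten log calls use f-string debug specifiers (`{expr=}`), which print both the expression and its repr in one token. Plain Python, nothing dify-specific:

node_id = "abc123"
version = "1"
# f"{expr=}" expands to "expr=<repr of value>"
print(f"error while running node, {node_id=}, {version=}")
# -> error while running node, node_id='abc123', version='1'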
@@ -465,10 +465,10 @@ class WorkflowService:
node_id: str,
) -> WorkflowNodeExecution:
try:
node_instance, generator = invoke_node_fn()
node, node_events = invoke_node_fn()

node_run_result: NodeRunResult | None = None
for event in generator:
for event in node_events:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result

@@ -479,18 +479,18 @@ class WorkflowService:
if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
"metadata": {"error_strategy": node.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
**node.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
@@ -509,10 +509,10 @@ class WorkflowService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
node = e._node
run_succeeded = False
node_run_result = None
error = e.error
error = e._error

# Create a NodeExecution domain model
node_execution = WorkflowNodeExecution(
@@ -520,8 +520,8 @@ class WorkflowService:
workflow_id="", # This is a single-step execution, so no workflow ID
index=1,
node_id=node_id,
node_type=node_instance.node_type,
title=node_instance.node_data.title,
node_type=node.type_,
title=node.title,
elapsed_time=time.perf_counter() - start_at,
created_at=datetime.now(UTC).replace(tzinfo=None),
finished_at=datetime.now(UTC).replace(tzinfo=None),

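Editor's note: the service layer stops reaching into `node_instance.node_data` and relies on properties the node itself exposes (`continue_on_error`, `error_strategy`, `default_value_dict`, `title`). A small sketch of the encapsulation this buys, with invented names rather than dify's real classes:

from enum import Enum
from typing import Any, Optional


class ErrorStrategy(Enum):
    DEFAULT_VALUE = "default-value"


class Node:
    def __init__(self, error_strategy: Optional[ErrorStrategy], defaults: dict[str, Any]) -> None:
        self._error_strategy = error_strategy
        self._defaults = defaults

    @property
    def error_strategy(self) -> Optional[ErrorStrategy]:
        return self._error_strategy

    @property
    def continue_on_error(self) -> bool:
        # Derived here once instead of re-derived by every caller.
        return self._error_strategy is not None

    @property
    def default_value_dict(self) -> dict[str, Any]:
        return self._defaults


node = Node(ErrorStrategy.DEFAULT_VALUE, {"text": ""})
if node.continue_on_error and node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
    outputs = {**node.default_value_dict, "error_message": "boom"}

Callers depend only on the node's public surface, so the internal storage of node data can change without touching the service.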
@@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
model_provider_factory = ModelProviderFactory(tenant_id="test_tenant")
model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(

@@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
config=code_config,
)

# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])

return node


@@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}

node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)

# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)

# construct result
result = {
@@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():

# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)

# construct result
result = {
@@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():

# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)

# construct result
result = {
@@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():

# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)


def test_execute_code_output_object_list():
@@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
]
}

node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)

# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)

# construct result
result = {
@@ -353,7 +357,7 @@ def test_execute_code_output_object_list():

# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)


def test_execute_code_scientific_notation():

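Editor's note: the test helpers converge on the same two-step construction, build the node from its full config, then hand the config's "data" section to `init_node_data`. A hedged sketch of the updated helper shape (toy node class, not the real fixtures):

from collections.abc import Mapping
from typing import Any


class ToyNode:
    def __init__(self, id: str, config: Mapping[str, Any]) -> None:
        self.id = id
        self.config = config
        self._node_data: Mapping[str, Any] | None = None

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # In dify this step validates into a typed pydantic model.
        self._node_data = data


def init_node(config: dict) -> ToyNode:
    node = ToyNode(id="test", config=config)
    if "data" in config:
        node.init_node_data(config["data"])
    return node


node = init_node({"id": "1", "data": {"title": "demo"}})
assert node._node_data == {"title": "demo"}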
@@ -52,7 +52,7 @@ def init_http_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)

return HttpRequestNode(
node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
@@ -60,6 +60,12 @@ def init_http_node(config: dict):
config=config,
)

# Initialize node data
if "data" in config:
node.init_node_data(config["data"])

return node


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):

@@ -2,15 +2,10 @@ import json
import time
import uuid
from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
@@ -24,8 +19,6 @@ from models.enums import UserFrom
from models.workflow import WorkflowType

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


def init_llm_node(config: dict) -> LLMNode:
@@ -84,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
config=config,
)

# Initialize node data
if "data" in config:
node.init_node_data(config["data"])

return node


def test_execute_llm(flask_req_ctx):
def test_execute_llm():
node = init_llm_node(
config={
"id": "llm",
@@ -95,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
@@ -114,53 +111,62 @@ def test_execute_llm(flask_req_ctx):
},
)

# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
db.session.close = MagicMock()

mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock

mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage

# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response from mock")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}

# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config

# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []

with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
# execute node
result = node._run()
@@ -168,6 +174,9 @@ def test_execute_llm(flask_req_ctx):

for item in result:
if isinstance(item, RunCompletedEvent):
if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.run_result.error}")
print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
@@ -175,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
def test_execute_llm_with_jinja2():
"""
Test execute LLM node with jinja2
"""
@@ -217,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close()
db.session.close = MagicMock()

# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)

mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")

mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)

# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock

from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage

# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}

return mock_model_instance, mock_model_config

# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []

with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
# execute node
result = node._run()

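Editor's note: the reworked LLM tests patch at the class level, `patch.object(LLMNode, ...)`, rather than on the instance, which also covers internal call sites that resolve the method through the class. A self-contained illustration with stand-in names:

from unittest.mock import patch


class Service:
    def fetch(self) -> str:
        return "real"

    def run(self) -> str:
        # Resolves fetch through the class at call time,
        # so a class-level patch takes effect here too.
        return self.fetch().upper()


def fake_fetch(self) -> str:
    return "mocked"


with patch.object(Service, "fetch", fake_fetch):
    assert Service().run() == "MOCKED"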
@@ -74,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)

return ParameterExtractorNode(
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node


def test_function_calling_parameter_extractor(setup_model_mock):

@@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))

# execute node
result = node._run()

@@ -50,13 +50,15 @@ def init_tool_node(config: dict):
conversation_variables=[],
)

return ToolNode(
node = ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node


def test_tool_variable_invoke():

@@ -58,21 +58,26 @@ def test_execute_answer():
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")

node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}

node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Mock db.session.close()
db.session.close = MagicMock()


@@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
),
),
)

node_config = {
"id": "1",
"data": data.model_dump(),
}

node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0,
),
)

# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
),
),
)

node_config = {
"id": "1",
"data": data.model_dump(),
}

node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0,
),
)

# Initialize node data
node.init_node_data(node_config["data"])

monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)

node_config = {
"id": "1",
"data": data.model_dump(),
}

node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)

# Initialize node data
node.init_node_data(node_config["data"])

monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",

@@ -162,25 +162,30 @@ def test_run():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])

node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}

iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)

# Initialize node data
iteration_node.init_node_data(node_config["data"])

def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -379,25 +384,30 @@ def test_run_parallel():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])

node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}

iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)

# Initialize node data
iteration_node.init_node_data(node_config["data"])

def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])

parallel_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}

parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=parallel_node_config,
)

# Initialize node data
parallel_iteration_node.init_node_data(parallel_node_config["data"])
sequential_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}

sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=sequential_node_config,
)

# Initialize node data
sequential_iteration_node.init_node_data(sequential_node_config["data"])

def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node.node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
@@ -818,26 +838,31 @@ def test_iteration_run_error_handle():
environment_variables=[],
)
pool.add(["pe", "list_output"], ["1", "1"])
error_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
}

iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
},
config=error_node_config,
)

# Initialize node data
iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node
result = iteration_node._run()
result_arr = []
@@ -851,7 +876,7 @@ def test_iteration_run_error_handle():

assert count == 14
# execute remove abnormal output
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:

@@ -119,17 +119,20 @@ def llm_node(
|
||||
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
|
||||
) -> LLMNode:
|
||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
}
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
},
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
@@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
|
||||
variable_pool = llm_node.graph_runtime_state.variable_pool
|
||||
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
result = llm_node._handle_list_messages(
|
||||
result = llm_node.handle_list_messages(
|
||||
messages=messages,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
@@ -506,17 +509,20 @@ def llm_node_for_multimodal(
|
||||
llm_node_data, graph_init_params, graph, graph_runtime_state
|
||||
) -> tuple[LLMNode, LLMFileSaver]:
|
||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
}
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
},
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node, mock_file_saver
|
||||
|
||||
|
||||
@@ -540,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
|
||||
size=9,
|
||||
)
|
||||
mock_file_saver.save_binary_string.return_value = mock_file
|
||||
file = llm_node._save_multimodal_image_output(content=content)
|
||||
file = llm_node.save_multimodal_image_output(
|
||||
content=content,
|
||||
file_saver=mock_file_saver,
|
||||
)
|
||||
# Manually append to _file_outputs since the static method doesn't do it
|
||||
llm_node._file_outputs.append(file)
|
||||
assert llm_node._file_outputs == [mock_file]
|
||||
assert file == mock_file
|
||||
mock_file_saver.save_binary_string.assert_called_once_with(
|
||||
@@ -566,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
|
||||
size=9,
|
||||
)
|
||||
mock_file_saver.save_remote_url.return_value = mock_file
|
||||
file = llm_node._save_multimodal_image_output(content=content)
|
||||
file = llm_node.save_multimodal_image_output(
|
||||
content=content,
|
||||
file_saver=mock_file_saver,
|
||||
)
|
||||
# Manually append to _file_outputs since the static method doesn't do it
|
||||
llm_node._file_outputs.append(file)
|
||||
assert llm_node._file_outputs == [mock_file]
|
||||
assert file == mock_file
|
||||
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
|
||||
@@ -582,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
|
||||
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
|
||||
def test_str_content(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
contents="hello world", file_saver=mock_file_saver, file_outputs=[]
|
||||
)
|
||||
assert list(gen) == ["hello world"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
@@ -590,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
|
||||
def test_text_prompt_message_content(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
[TextPromptMessageContent(data="hello world")]
|
||||
contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
|
||||
)
|
||||
assert list(gen) == ["hello world"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
@@ -616,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
contents=[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
]
],
file_saver=mock_file_saver,
file_outputs=llm_node._file_outputs,
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
@@ -645,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:

def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=None, file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

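_save_multimodal_output_and_convert_result_to_markdown now takes contents, file_saver, and file_outputs as keywords while staying a generator. A rough sketch of the behavior these assertions pin down (a toy re-implementation inferred from the tests, not the node's actual code):

from collections.abc import Sequence

def save_and_convert(*, contents, file_saver, file_outputs):
    # inferred contract: None yields nothing, strings pass through,
    # list items and unknown types fall back to str()
    if contents is None:
        return
    if isinstance(contents, str):
        yield contents
    elif isinstance(contents, Sequence):
        for item in contents:
            yield item if isinstance(item, str) else str(item)
    else:
        yield str(contents)

assert list(save_and_convert(contents="hello world", file_saver=None, file_outputs=[])) == ["hello world"]
assert list(save_and_convert(contents=None, file_saver=None, file_outputs=[])) == []
assert list(save_and_convert(contents=frozenset(["hello world"]), file_saver=None, file_outputs=[])) == ["frozenset({'hello world'})"]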
@@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")

node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}

node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Mock db.session.close()
db.session.close = MagicMock()

@@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
return DocumentExtractorNode(
node_config = {"id": "test_node_id", "data": node_data.model_dump()}
node = DocumentExtractorNode(
id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()},
config=node_config,
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node


@pytest.fixture

@@ -57,57 +57,62 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")

node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "≥",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
}

node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "≥",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Mock db.session.close()
db.session.close = MagicMock()

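The paired true/false tests reduce to how a conditions list combines under the node's logical_operator. A tiny sketch of those and/or semantics (toy evaluator, assuming each condition has already been reduced to a boolean):

def combine(results: list[bool], logical_operator: str) -> bool:
    # "and" requires every condition to pass; "or" needs any one of them
    return all(results) if logical_operator == "and" else any(results)

assert combine([True, True], "and") is True
assert combine([True, False], "and") is False
assert combine([False, False], "or") is False  # the case the next test drives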
@@ -162,33 +167,38 @@ def test_execute_if_else_result_false():
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])

node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
}

node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Mock db.session.close()
db.session.close = MagicMock()

@@ -228,17 +238,22 @@ def test_array_file_contains_file_name():
],
)

node_config = {
"id": "if-else",
"data": node_data.model_dump(),
}

node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
config={
"id": "if-else",
"data": node_data.model_dump(),
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(

@@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),
}
node = ListOperatorNode(
id="test_node_id",
config={
"id": "test_node_id",
"data": node_data.model_dump(),
},
config=node_config,
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node

@@ -38,12 +38,13 @@ def _create_tool_node():
system_variables=SystemVariable.empty(),
user_inputs={},
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = ToolNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -71,6 +72,8 @@ def _create_tool_node():
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node

@@ -82,23 +82,28 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

# Initialize node data
node.init_node_data(node_config["data"])

list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@@ -178,23 +183,28 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

# Initialize node data
node.init_node_data(node_config["data"])

list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
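Both variable-assigner tests above inject conv_var_updater_factory, the new constructor seam that replaces patching persistence directly. A sketch of exercising such a seam (toy node; the update/flush method names are hypothetical, only the factory keyword comes from the diff):

from unittest import mock

class ToyVariableAssignerNode:
    """Illustrative stand-in showing the factory-injection seam."""

    def __init__(self, *, conv_var_updater_factory) -> None:
        self._conv_var_updater_factory = conv_var_updater_factory

    def run(self):
        updater = self._conv_var_updater_factory()  # one updater per run
        updater.update("conversation_var", "new value")  # hypothetical method name
        updater.flush()  # hypothetical method name
        yield "done"

mock_updater = mock.Mock()
factory = mock.Mock(return_value=mock_updater)
node = ToyVariableAssignerNode(conv_var_updater_factory=factory)
list(node.run())
factory.assert_called_once()
mock_updater.flush.assert_called_once()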
@@ -265,23 +275,28 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
}

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

# Initialize node data
node.init_node_data(node_config["data"])

list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,

@@ -115,28 +115,33 @@ def test_remove_first_from_array():
conversation_variables=[conversation_variable],
)

node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
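The v2 variable-assigner data replaces the single selector pair with an items list, each entry naming a variable_selector, input_type, operation, and optional value. A hedged sketch of the remove-first/remove-last semantics the surrounding tests exercise (illustrative stand-in enum; the empty-array no-op is inferred from the *_from_empty_array tests):

from enum import Enum

class Op(Enum):  # stand-in for the project's Operation enum
    REMOVE_FIRST = "remove-first"
    REMOVE_LAST = "remove-last"

def apply_operation(values: list, operation: Op) -> list:
    if not values:
        return values  # removing from an empty array is a no-op, not an error
    if operation is Op.REMOVE_FIRST:
        return values[1:]
    if operation is Op.REMOVE_LAST:
        return values[:-1]
    raise ValueError(f"unsupported operation: {operation}")

assert apply_operation(["a", "b", "c"], Op.REMOVE_FIRST) == ["b", "c"]
assert apply_operation(["a", "b", "c"], Op.REMOVE_LAST) == ["a", "b"]
assert apply_operation([], Op.REMOVE_FIRST) == []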
@@ -202,28 +207,33 @@ def test_remove_last_from_array():
conversation_variables=[conversation_variable],
)

node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Skip the mock assertion since we're in a test environment
list(node.run())

@@ -281,28 +291,33 @@ def test_remove_first_from_empty_array():
conversation_variables=[conversation_variable],
)

node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Skip the mock assertion since we're in a test environment
list(node.run())

@@ -360,28 +375,33 @@ def test_remove_last_from_empty_array():
conversation_variables=[conversation_variable],
)

node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)

# Initialize node data
node.init_node_data(node_config["data"])

# Skip the mock assertion since we're in a test environment
list(node.run())