refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Author: -LAN-
Date: 2025-07-18 10:08:51 +08:00
Committed by: GitHub
Commit: 460a825ef1 (parent: 54c56f2d05)
65 changed files with 2305 additions and 1146 deletions
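The test updates below all apply the same pattern: node construction is separated from node-data binding, so each test builds the node from its config dict, then calls init_node_data(config["data"]), and internal reads go through the private _node_data attribute instead of the old public node_data. A minimal sketch of that pattern, using a hypothetical MyNode/MyNodeData pair in place of the real node classes:

from dataclasses import dataclass


@dataclass
class MyNodeData:
    title: str
    type: str


class MyNode:
    def __init__(self, id: str, config: dict):
        self.id = id
        self._config = config
        self._node_data = None

    def init_node_data(self, data: dict) -> None:
        # Node data is parsed and attached in a separate, explicit step, so the
        # node object no longer requires a NodeData instance at construction time.
        self._node_data = MyNodeData(**data)


node_config = {"id": "example", "data": {"title": "123", "type": "example"}}
node = MyNode(id="example", config=node_config)
if "data" in node_config:
    node.init_node_data(node_config["data"])
assert node._node_data is not None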


@@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
model_provider_factory = ModelProviderFactory(tenant_id="test_tenant")
model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(


@@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
config=code_config,
)
# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])
return node
@@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list():
@@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
]
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -353,7 +357,7 @@ def test_execute_code_output_object_list():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation():


@@ -52,7 +52,7 @@ def init_http_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return HttpRequestNode(
node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
@@ -60,6 +60,12 @@ def init_http_node(config: dict):
config=config,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):


@@ -2,15 +2,10 @@ import json
import time
import uuid
from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
@@ -24,8 +19,6 @@ from models.enums import UserFrom
from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
def init_llm_node(config: dict) -> LLMNode:
@@ -84,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
config=config,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
def test_execute_llm(flask_req_ctx):
def test_execute_llm():
node = init_llm_node(
config={
"id": "llm",
@@ -95,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
@@ -114,53 +111,62 @@ def test_execute_llm(flask_req_ctx):
},
)
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
db.session.close = MagicMock()
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response from mock")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
# execute node
result = node._run()
@@ -168,6 +174,9 @@ def test_execute_llm(flask_req_ctx):
for item in result:
if isinstance(item, RunCompletedEvent):
if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.run_result.error}")
print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
@@ -175,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
def test_execute_llm_with_jinja2():
"""
Test execute LLM node with jinja2
"""
@@ -217,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
# execute node
result = node._run()


@@ -74,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return ParameterExtractorNode(
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node
def test_function_calling_parameter_extractor(setup_model_mock):


@@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
# execute node
result = node._run()


@@ -50,13 +50,15 @@ def init_tool_node(config: dict):
conversation_variables=[],
)
return ToolNode(
node = ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node
def test_tool_variable_invoke():


@@ -58,21 +58,26 @@ def test_execute_answer():
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()


@@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
),
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
),
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",


@@ -162,25 +162,30 @@ def test_run():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -379,25 +384,30 @@ def test_run_parallel():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
parallel_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=parallel_node_config,
)
# Initialize node data
parallel_iteration_node.init_node_data(parallel_node_config["data"])
sequential_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=sequential_node_config,
)
# Initialize node data
sequential_iteration_node.init_node_data(sequential_node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node.node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
@@ -818,26 +838,31 @@ def test_iteration_run_error_handle():
environment_variables=[],
)
pool.add(["pe", "list_output"], ["1", "1"])
error_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
},
config=error_node_config,
)
# Initialize node data
iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node
result = iteration_node._run()
result_arr = []
@@ -851,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14
# execute remove abnormal output
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:


@@ -119,17 +119,20 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages(
result = llm_node.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
@@ -506,17 +509,20 @@ def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node, mock_file_saver
@@ -540,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
@@ -566,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
@@ -582,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents="hello world", file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
@@ -590,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")]
contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
@@ -616,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
contents=[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
]
],
file_saver=mock_file_saver,
file_outputs=llm_node._file_outputs,
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
@@ -645,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=None, file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
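The LLM node hunks above also show the other side of the decoupling: multimodal helpers such as save_multimodal_image_output now receive their collaborators (the file saver and the list that accumulates file outputs) as explicit arguments instead of reading them off the node instance. A minimal sketch of that call style, with FileSaver and save_image_output as illustrative stand-ins rather than the real Dify signatures:

from typing import Protocol
from unittest import mock


class FileSaver(Protocol):
    def save_binary_string(self, data: bytes, mime_type: str) -> str: ...


def save_image_output(content: bytes, *, file_saver: FileSaver) -> str:
    # The saver is passed in explicitly; the caller decides where to record the
    # returned file (mirroring the manual append to _file_outputs in the tests).
    return file_saver.save_binary_string(content, "image/png")


saver = mock.MagicMock()
saver.save_binary_string.return_value = "file-1"
outputs: list[str] = []
outputs.append(save_image_output(b"\x89PNG", file_saver=saver))
assert outputs == ["file-1"]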


@@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()


@@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
return DocumentExtractorNode(
node_config = {"id": "test_node_id", "data": node_data.model_dump()}
node = DocumentExtractorNode(
id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()},
config=node_config,
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@pytest.fixture


@@ -57,57 +57,62 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@@ -162,33 +167,38 @@ def test_execute_if_else_result_false():
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@@ -228,17 +238,22 @@ def test_array_file_contains_file_name():
],
)
node_config = {
"id": "if-else",
"data": node_data.model_dump(),
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
config={
"id": "if-else",
"data": node_data.model_dump(),
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(


@@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),
}
node = ListOperatorNode(
id="test_node_id",
config={
"id": "test_node_id",
"data": node_data.model_dump(),
},
config=node_config,
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node


@@ -38,12 +38,13 @@ def _create_tool_node():
system_variables=SystemVariable.empty(),
user_inputs={},
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = ToolNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -71,6 +72,8 @@ def _create_tool_node():
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node


@@ -82,23 +82,28 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@@ -178,23 +183,28 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
@@ -265,23 +275,28 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,


@@ -115,28 +115,33 @@ def test_remove_first_from_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@@ -202,28 +207,33 @@ def test_remove_last_from_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -281,28 +291,33 @@ def test_remove_first_from_empty_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -360,28 +375,33 @@ def test_remove_last_from_empty_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())