chore(api): upgrade graphon to v0.3.0 (#35469)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
-LAN-
2026-05-09 15:30:03 +08:00
committed by GitHub
parent f3eb3ab4dd
commit 19476109da
80 changed files with 2526 additions and 673 deletions

View File

@@ -73,7 +73,7 @@ def test_node_integration_minimal_stream(mocker: MockerFixture):
node = DatasourceNode(
node_id="n",
config=DatasourceNodeData(
data=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",

View File

@@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from graphon.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderType
@@ -15,8 +15,9 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
model_assembly = create_plugin_model_assembly(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_provider_factory = model_assembly.model_provider_factory
model_type_instance = model_assembly.create_model_type_instance(provider=provider, model_type=ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",

View File

@@ -45,7 +45,7 @@ def init_code_node(code_config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@@ -66,7 +66,7 @@ def init_code_node(code_config: dict):
node = CodeNode(
node_id=str(uuid.uuid4()),
config=CodeNodeData.model_validate(code_config["data"]),
data=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,

View File

@@ -55,7 +55,7 @@ def init_http_node(config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@@ -76,7 +76,7 @@ def init_http_node(config: dict):
node = HttpRequestNode(
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(config["data"]),
data=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
from graphon.runtime import VariablePool
# Create variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="test", files=[]),
user_inputs={},
environment_variables=[],
@@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock):
)
# Create independent variable pool for this test only
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@@ -724,7 +724,7 @@ def test_nested_object_variable_selector(setup_http_mock):
node = HttpRequestNode(
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
data=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode:
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="aaa",
app_id=app_id,
@@ -77,7 +77,7 @@ def init_llm_node(config: dict) -> LLMNode:
node = LLMNode(
node_id=str(uuid.uuid4()),
config=LLMNodeData.model_validate(config["data"]),
data=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
),
@@ -71,7 +71,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
node = ParameterExtractorNode(
node_id=str(uuid.uuid4()),
config=ParameterExtractorNodeData.model_validate(config["data"]),
data=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@@ -66,7 +66,7 @@ def test_execute_template_transform():
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@@ -88,7 +88,7 @@ def test_execute_template_transform():
node = TemplateTransformNode(
node_id=str(uuid.uuid4()),
config=TemplateTransformNodeData.model_validate(config["data"]),
data=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),

View File

@@ -43,7 +43,7 @@ def init_tool_node(config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@@ -64,7 +64,7 @@ def init_tool_node(config: dict):
node = ToolNode(
node_id=str(uuid.uuid4()),
config=ToolNodeData.model_validate(config["data"]),
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@@ -210,7 +210,9 @@ class TestPauseStatePersistenceLayerTestContainers:
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
# Create variable pool
variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id))
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id=execution_id)
)
if variables:
for (node_id, var_key), value in variables.items():
variable_pool.add([node_id, var_key], value)

View File

@@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
workflow_execution_id=workflow_execution_id,
app_id=app_id,
@@ -102,7 +102,7 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
node_id="start",
config=start_data,
data=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@@ -117,7 +117,7 @@ def _build_graph(
)
human_node = HumanInputNode(
node_id="human",
config=human_data,
data=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@@ -131,7 +131,7 @@ def _build_graph(
)
end_node = EndNode(
node_id="end",
config=end_data,
data=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)

View File

@@ -90,16 +90,34 @@ class TestCreditPoolService:
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
self, db_session_with_containers: Session
):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_used = pool.quota_used
db_session_with_containers.commit()
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == quota_used
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_limit = pool.quota_limit
db_session_with_containers.commit()
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == pool.quota_limit
assert updated_pool.quota_used == quota_limit

View File

@@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline:
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=build_test_variable_pool(
variables=build_system_variables(workflow_execution_id="run-id"),
),
start_at=0.0,
total_tokens=7,
node_run_steps=3,
@@ -372,7 +374,9 @@ class TestAdvancedChatGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@@ -583,7 +587,9 @@ class TestAdvancedChatGenerateTaskPipeline:
self.items = items
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
@@ -617,7 +623,9 @@ class TestAdvancedChatGenerateTaskPipeline:
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch: pytest.MonkeyPatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"

View File

@@ -60,7 +60,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
def __init__(
self,
node_id: str,
config: _StubToolNodeData,
data: _StubToolNodeData,
*,
graph_init_params,
graph_runtime_state,
@@ -68,7 +68,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
@@ -169,7 +169,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],

View File

@@ -54,7 +54,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@@ -93,7 +93,7 @@ class TestWorkflowBasedAppRunner:
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch: pytest.MonkeyPatch):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@@ -164,7 +164,7 @@ class TestWorkflowBasedAppRunner:
app_id="app",
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@@ -243,7 +243,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
graph_runtime_state.register_paused_node("node-1")
@@ -286,7 +286,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
@@ -425,7 +425,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))

View File

@@ -16,7 +16,7 @@ from models.workflow import Workflow
def _make_graph_state():
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
environment_variables=[],

View File

@@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline:
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=build_test_variable_pool(
variables=build_system_variables(workflow_execution_id="run-id"),
),
start_at=0.0,
total_tokens=5,
node_run_steps=2,
@@ -283,7 +285,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@@ -725,7 +729,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
@@ -753,7 +759,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
@@ -769,7 +777,9 @@ class TestWorkflowGenerateTaskPipeline:
def test_process_stream_response_main_match_paths_and_cleanup(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(

View File

@@ -21,7 +21,9 @@ class TestTriggerPostLayer:
)
runtime_state = SimpleNamespace(
outputs={"answer": "ok"},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=12,
)
@@ -60,7 +62,9 @@ class TestTriggerPostLayer:
def test_on_event_handles_missing_trigger_log(self):
runtime_state = SimpleNamespace(
outputs={},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=0,
)
@@ -91,7 +95,9 @@ class TestTriggerPostLayer:
def test_on_event_ignores_non_status_events(self):
runtime_state = SimpleNamespace(
outputs={},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=0,
)

View File

@@ -0,0 +1,617 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine, select
from configs import dify_config
from core.app.llm.quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
)
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from models import TenantCreditPool
from models.enums import ProviderQuotaType as ModelProviderQuotaType
from models.provider import Provider, ProviderType
def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None:
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
ensure_llm_quota_available_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
provider_configuration.get_provider_model.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-4o",
)
def test_ensure_llm_quota_available_for_model_raises_when_provider_is_missing() -> None:
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = None
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
pytest.raises(ValueError, match="Provider openai does not exist."),
):
ensure_llm_quota_available_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
def test_ensure_llm_quota_available_for_model_ignores_custom_provider_configuration() -> None:
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.CUSTOM,
get_provider_model=MagicMock(),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager):
ensure_llm_quota_available_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
provider_configuration.get_provider_model.assert_not_called()
def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 42
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=42,
)
def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 3
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": "trial-pool",
"tenant_id": "tenant-id",
"pool_type": ModelProviderQuotaType.TRIAL,
"quota_limit": 10,
"quota_used": 9,
},
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == "trial-pool"))
assert quota_used == 10
def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 42
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=-1,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_not_called()
def test_deduct_llm_quota_for_model_uses_credit_configuration() -> None:
usage = LLMUsage.empty_usage()
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.CREDITS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch.object(type(dify_config), "get_model_credits", return_value=9) as mock_get_model_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_get_model_credits.assert_called_once_with("gpt-4o")
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=9,
)
def test_deduct_llm_quota_for_model_uses_single_charge_for_times_quota() -> None:
usage = LLMUsage.empty_usage()
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TIMES,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=1,
)
def test_deduct_llm_quota_for_model_uses_paid_billing_pool() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 5
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.PAID,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.PAID,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=5,
pool_type="paid",
)
def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 3
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.FREE,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.FREE,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
engine = create_engine("sqlite:///:memory:")
Provider.__table__.create(engine)
with engine.begin() as connection:
connection.execute(
Provider.__table__.insert(),
[
{
"id": "matching-provider",
"tenant_id": "tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 10,
"is_valid": True,
},
{
"id": "other-tenant",
"tenant_id": "other-tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 20,
"is_valid": True,
},
{
"id": "other-provider",
"tenant_id": "tenant-id",
"provider_name": "anthropic",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 30,
"is_valid": True,
},
{
"id": "custom-provider",
"tenant_id": "tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.CUSTOM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 40,
"is_valid": True,
},
],
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
quota_used_by_id = dict(connection.execute(select(Provider.id, Provider.quota_used)).all())
assert quota_used_by_id == {
"matching-provider": 13,
"other-tenant": 20,
"other-provider": 30,
"custom-provider": 40,
}
with engine.begin() as connection:
connection.execute(
Provider.__table__.update().where(Provider.id == "matching-provider").values(quota_limit=13, quota_used=13)
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
exhausted_quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))
assert exhausted_quota_used == 13
def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 3
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.FREE,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.FREE,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
engine = create_engine("sqlite:///:memory:")
Provider.__table__.create(engine)
with engine.begin() as connection:
connection.execute(
Provider.__table__.insert(),
{
"id": "matching-provider",
"tenant_id": "tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 15,
"quota_used": 13,
"is_valid": True,
},
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))
assert quota_used == 15
def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 2
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type="unexpected",
quota_configurations=[
SimpleNamespace(
quota_type="unexpected",
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_not_called()
mock_sessionmaker.assert_not_called()
def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 2
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.CUSTOM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_not_called()
mock_sessionmaker.assert_not_called()
def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None:
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
)
with (
pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure,
):
ensure_llm_quota_available(model_instance=model_instance)
mock_ensure.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
def test_ensure_llm_quota_available_wrapper_rejects_non_llm_model_instances() -> None:
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
)
with (
pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
pytest.raises(ValueError, match="only support LLM model instances"),
):
ensure_llm_quota_available(model_instance=model_instance)
def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 7
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
)
with (
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct,
):
deduct_llm_quota(
tenant_id="tenant-id",
model_instance=model_instance,
usage=usage,
)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
def test_deduct_llm_quota_wrapper_rejects_non_llm_model_instances() -> None:
usage = LLMUsage.empty_usage()
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
)
with (
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
pytest.raises(ValueError, match="only support LLM model instances"),
):
deduct_llm_quota(
tenant_id="tenant-id",
model_instance=model_instance,
usage=usage,
)

View File

@@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs):
self.id = node_id
self.config = config
self.data = data
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self.kwargs = kwargs

View File

@@ -60,7 +60,10 @@ def _make_layer(
workflow_execution_id="run-id",
conversation_id="conv-id",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool.from_bootstrap(system_variables=system_variables),
start_at=0.0,
)
read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(

View File

@@ -354,7 +354,8 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
with patch(
@@ -379,7 +380,10 @@ def test_validate_provider_credentials_without_credential_id() -> None:
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
assert validated == {"region": "us"}
@@ -426,23 +430,37 @@ def test_switch_preferred_provider_type_creates_record_when_missing() -> None:
def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
configuration = _build_provider_configuration()
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
mock_assembly = Mock()
mock_assembly.model_runtime = Mock()
mock_assembly.model_provider_factory = mock_factory
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory",
return_value=mock_factory,
) as mock_factory_builder:
with (
patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=mock_assembly,
) as mock_assembly_builder,
patch(
"core.entities.provider_configuration.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_model_builder,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_builder.call_count == 2
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
assert mock_assembly_builder.call_count == 2
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=mock_assembly.model_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
mock_factory.get_model_schema.assert_called_once_with(
provider="openai",
model_type=ModelType.LLM,
@@ -456,17 +474,21 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
bound_runtime = Mock()
configuration.bind_model_runtime(bound_runtime)
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
with (
patch(
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
) as mock_factory_cls,
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder,
patch(
"core.entities.provider_configuration.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_model_builder,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
@@ -474,8 +496,14 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_cls.call_count == 2
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
mock_factory_builder.assert_not_called()
mock_factory_cls.assert_called_with(runtime=bound_runtime)
mock_assembly_builder.assert_not_called()
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=bound_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
def test_get_provider_model_returns_none_when_model_not_found() -> None:
@@ -504,7 +532,10 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = provider_schema
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
@@ -722,7 +753,8 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
validated = configuration.validate_provider_credentials(
@@ -1069,7 +1101,8 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
@@ -1083,7 +1116,10 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
mock_factory2 = Mock()
mock_factory2.model_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory2),
):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
@@ -1575,7 +1611,8 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_provider_credentials(
@@ -1701,7 +1738,8 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(

View File

@@ -68,8 +68,8 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
assert (
check_moderation(
@@ -91,7 +91,7 @@ def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture)
provider_map={openai_provider: hosting_openai},
),
)
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory")
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_assembly")
choice_mock = mocker.patch("core.helper.moderation.secrets.choice")
assert (
@@ -119,8 +119,8 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
assert (
check_moderation(
@@ -147,8 +147,8 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo
failing_model = SimpleNamespace(
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: failing_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
check_moderation(

View File

@@ -2,6 +2,7 @@ from unittest.mock import Mock
import pytest
from core.plugin.impl.model_runtime_factory import create_model_type_instance
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -73,7 +74,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None:
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
provider_schema = factory.get_model_provider("openai")
@@ -98,7 +99,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
provider_schema = factory.get_model_provider("openai")
@@ -107,8 +108,8 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
def test_model_provider_factory_requires_runtime() -> None:
with pytest.raises(ValueError, match="model_runtime is required"):
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
with pytest.raises(ValueError, match="runtime is required"):
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
@@ -119,7 +120,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non
supported_model_types=[ModelType.LLM],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
result = factory.get_providers()
@@ -133,7 +134,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
result = factory.get_provider_schema("openai")
@@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
def test_model_provider_factory_raises_for_unknown_provider() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@@ -172,7 +173,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
models=[_build_model("rerank-v3", ModelType.RERANK)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai", model_type=ModelType.LLM)
@@ -196,7 +197,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod
models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(model_type=ModelType.TTS)
@@ -214,7 +215,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider
models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai")
@@ -242,7 +243,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.provider_credentials_validate(
provider="openai",
@@ -258,7 +259,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@@ -294,7 +295,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.model_credentials_validate(
provider="openai",
@@ -314,7 +315,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@@ -346,7 +347,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
)
runtime.get_model_schema.return_value = "schema"
runtime.get_provider_icon.return_value = (b"icon", "image/png")
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
assert (
factory.get_model_schema(
@@ -382,39 +383,43 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
(ModelType.TTS, TTSModel),
],
)
def test_model_provider_factory_builds_model_type_instances(
def test_create_model_type_instance_builds_model_wrappers(
model_type: ModelType,
expected_type: type[object],
) -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[model_type],
)
]
)
runtime = _FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[model_type],
)
]
)
instance = factory.get_model_type_instance("openai", model_type)
instance = create_model_type_instance(
runtime=runtime,
provider_schema=runtime.fetch_model_providers()[0],
model_type=model_type,
)
assert isinstance(instance, expected_type)
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
]
)
def test_create_model_type_instance_rejects_unsupported_model_type() -> None:
runtime = _FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
]
)
with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]
create_model_type_instance(
runtime=runtime,
provider_schema=runtime.fetch_model_providers()[0],
model_type="unsupported", # type: ignore[arg-type]
)

View File

@@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
assert assembly.model_manager is model_manager
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)

View File

@@ -3,7 +3,7 @@
import datetime
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, sentinel
from unittest.mock import Mock, patch, sentinel
import pytest
@@ -13,6 +13,8 @@ from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
@@ -146,7 +148,31 @@ class TestPluginModelRuntime:
def test_invoke_llm_resolves_plugin_fields(self) -> None:
client = Mock(spec=PluginModelClient)
client.invoke_llm.return_value = sentinel.result
usage = LLMUsage.empty_usage()
client.invoke_llm.return_value = iter(
[
LLMResultChunk(
model="gpt-4o-mini",
prompt_messages=[],
system_fingerprint="fp-plugin",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content="plugin "),
),
),
LLMResultChunk(
model="gpt-4o-mini",
prompt_messages=[],
system_fingerprint="fp-plugin",
delta=LLMResultChunkDelta(
index=1,
message=AssistantPromptMessage(content="response"),
usage=usage,
finish_reason="stop",
),
),
]
)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
result = runtime.invoke_llm(
@@ -160,7 +186,11 @@ class TestPluginModelRuntime:
stream=False,
)
assert result is sentinel.result
assert result.model == "gpt-4o-mini"
assert result.prompt_messages == []
assert result.message.content == "plugin response"
assert result.usage == usage
assert result.system_fingerprint == "fp-plugin"
client.invoke_llm.assert_called_once_with(
tenant_id="tenant",
user_id="user",
@@ -175,6 +205,38 @@ class TestPluginModelRuntime:
stream=False,
)
def test_invoke_llm_returns_plugin_stream_directly(self) -> None:
client = Mock(spec=PluginModelClient)
stream_result = iter([])
client.invoke_llm.return_value = stream_result
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
result = runtime.invoke_llm(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0.3},
prompt_messages=[],
tools=None,
stop=("END",),
stream=True,
)
assert result is stream_result
client.invoke_llm.assert_called_once_with(
tenant_id="tenant",
user_id="user",
plugin_id="langgenius/openai",
provider="openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0.3},
prompt_messages=[],
tools=None,
stop=["END"],
stream=True,
)
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
client = Mock(spec=PluginModelClient)
client.invoke_llm.return_value = sentinel.result
@@ -267,6 +329,129 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch:
client.get_model_schema.assert_not_called()
def test_structured_output_adapter_invokes_bound_runtime_streaming() -> None:
runtime = Mock()
runtime.invoke_llm.return_value = sentinel.stream_result
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
runtime=runtime,
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
tool = Mock()
result = adapter.invoke_llm(
prompt_messages=[],
model_parameters=None,
tools=[tool],
stop=["END"],
stream=True,
callbacks=sentinel.callbacks,
)
assert result is sentinel.stream_result
runtime.invoke_llm.assert_called_once_with(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={},
prompt_messages=[],
tools=[tool],
stop=["END"],
stream=True,
)
def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None:
runtime = Mock()
runtime.invoke_llm.return_value = sentinel.result
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
runtime=runtime,
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
result = adapter.invoke_llm(
prompt_messages=[],
model_parameters={"temperature": 0},
tools=None,
stop=None,
stream=False,
)
assert result is sentinel.result
runtime.invoke_llm.assert_called_once_with(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0},
prompt_messages=[],
tools=None,
stop=None,
stream=False,
)
def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None:
client = Mock(spec=PluginModelClient)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
schema = _build_model_schema()
runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign]
with patch.object(
model_runtime_module,
"invoke_llm_with_structured_output_helper",
return_value=sentinel.structured_result,
) as mock_helper:
result = runtime.invoke_llm_with_structured_output(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
json_schema={"type": "object"},
model_parameters={"temperature": 0},
prompt_messages=[],
stop=("END",),
stream=False,
)
assert result is sentinel.structured_result
runtime.get_model_schema.assert_called_once_with(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
helper_kwargs = mock_helper.call_args.kwargs
assert helper_kwargs["provider"] == "langgenius/openai/openai"
assert helper_kwargs["model_schema"] == schema
assert helper_kwargs["json_schema"] == {"type": "object"}
assert helper_kwargs["model_parameters"] == {"temperature": 0}
assert helper_kwargs["prompt_messages"] == []
assert helper_kwargs["tools"] is None
assert helper_kwargs["stop"] == ["END"]
assert helper_kwargs["stream"] is False
assert isinstance(helper_kwargs["model_instance"], model_runtime_module._PluginStructuredOutputModelInstance)
def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None:
client = Mock(spec=PluginModelClient)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign]
with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"):
runtime.invoke_llm_with_structured_output(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
json_schema={"type": "object"},
model_parameters={},
prompt_messages=[],
stop=None,
stream=False,
)
def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None:
client = Mock(spec=PluginModelClient)
schema = _build_model_schema()

View File

@@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
result = manager.get_default_model("tenant-id", ModelType.LLM)
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result is not None
assert result.model == "gpt-4"
assert result.provider.provider == "openai"
@@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock
result = manager.get_configurations("tenant-id")
expected_alias = str(ModelProviderID("openai"))
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result.tenant_id == "tenant-id"
assert expected_alias in provider_records
assert expected_alias in provider_model_records
@@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF
assert first is second
mock_get_all_providers.assert_called_once_with("tenant-id")
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
mock_provider_configuration.assert_called_once()
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)

View File

@@ -1,12 +1,11 @@
import logging
import threading
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.graph_engine.entities.commands import CommandType
from graphon.graph_events import NodeRunSucceededEvent
@@ -14,17 +13,7 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import NodeRunResult
def _build_dify_context() -> DifyRunContext:
return DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
def _build_succeeded_event() -> NodeRunSucceededEvent:
def _build_succeeded_event(*, provider: str = "openai", model_name: str = "gpt-4o") -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="execution-id",
node_id="llm-node-id",
@@ -32,113 +21,162 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
start_at=datetime.now(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"question": "hello"},
inputs={
"question": "hello",
"model_provider": provider,
"model_name": model_name,
},
llm_usage=LLMUsage.empty_usage(),
),
)
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
raw_model_instance = ModelInstance.__new__(ModelInstance)
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
def _build_public_model_identity(*, provider: str = "openai", model_name: str = "gpt-4o") -> SimpleNamespace:
return SimpleNamespace(provider=provider, name=model_name)
def _build_node_data(*, model: SimpleNamespace | None = None) -> SimpleNamespace:
return SimpleNamespace(
error_strategy=None,
retry_config=SimpleNamespace(retry_enabled=False),
model=model,
)
def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
node = MagicMock()
node.id = "node-id"
node.execution_id = "execution-id"
node.node_type = node_type
node.node_data = _build_node_data(model=_build_public_model_identity())
node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
return node
class _RunnableQuotaNode:
id = "node-id"
execution_id = "execution-id"
node_type = BuiltinNodeTypes.LLM
title = "LLM node"
def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
self.node_data = node_data or _build_node_data(model=_build_public_model_identity())
self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
self.original_run_called = False
def _run(self) -> NodeRunResult:
self.original_run_called = True
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.LLM)
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=raw_model_instance,
provider="openai",
model="gpt-4o",
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
result_event = _build_succeeded_event(provider="anthropic", model_name="claude-3-7-sonnet")
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=raw_model_instance,
provider="anthropic",
model="claude-3-7-sonnet",
usage=result_event.node_run_result.llm_usage,
)
def test_non_llm_node_is_ignored() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "start-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.START
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node._model_instance = object()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.START)
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_not_called()
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance = object()
def test_precheck_ignores_non_quota_node() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.START)
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
mock_check.assert_not_called()
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
def test_quota_error_is_handled_in_layer(caplog) -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, _ = _build_wrapped_model_instance()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with (
caplog.at_level(logging.ERROR, logger="core.app.workflow.layers.llm_quota"),
patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
autospec=True,
side_effect=ValueError("quota exceeded"),
) as mock_deduct,
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=result_event.node_run_result.llm_usage,
)
assert "LLM quota deduction failed, node_id=node-id" in caplog.text
assert not stop_event.is_set()
layer.command_channel.send_command.assert_not_called()
def test_send_abort_command_is_noop_without_channel_or_after_abort() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
layer._send_abort_command(reason="no channel")
layer.command_channel = MagicMock()
layer._abort_sent = True
layer._send_abort_command(reason="already aborted")
layer.command_channel.send_command.assert_not_called()
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
):
@@ -152,19 +190,16 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance, _ = _build_wrapped_model_instance()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
@@ -177,21 +212,140 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
assert abort_command.reason == "Model provider openai quota exceeded."
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer()
def test_quota_precheck_failure_blocks_current_node_run() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
node = _RunnableQuotaNode(stop_event=stop_event)
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
layer.on_node_run_start(node)
result = node._run()
assert not node.original_run_called
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "Model provider openai quota exceeded."
assert result.error_type == QuotaExceededError.__name__
def test_missing_model_identity_blocks_current_node_run() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _RunnableQuotaNode(stop_event=stop_event, node_data=_build_node_data())
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
result = node._run()
assert not node.original_run_called
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "LLM quota check requires public node model identity before execution."
assert result.error_type == "LLMQuotaIdentityError"
mock_check.assert_not_called()
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=raw_model_instance)
mock_check.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
layer.command_channel.send_command.assert_not_called()
def test_precheck_reads_model_identity_from_data_when_node_data_is_absent() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = SimpleNamespace(
id="node-id",
node_type=BuiltinNodeTypes.LLM,
data=_build_node_data(model=_build_public_model_identity(provider="anthropic", model_name="claude")),
)
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
mock_check.assert_called_once_with(
tenant_id="tenant-id",
provider="anthropic",
model="claude",
)
def test_precheck_rejects_invalid_public_model_identity() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.node_data = _build_node_data(model=_build_public_model_identity(provider="", model_name="gpt-4o"))
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert stop_event.is_set()
mock_check.assert_not_called()
layer.command_channel.send_command.assert_called_once()
def test_precheck_requires_public_node_model_config() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.node_data = _build_node_data()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert stop_event.is_set()
mock_check.assert_not_called()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "LLM quota check requires public node model identity before execution."
def test_deduction_requires_public_event_model_identity() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
result_event.node_run_result.inputs = {"question": "hello"}
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
assert stop_event.is_set()
mock_deduct.assert_not_called()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "LLM quota deduction requires model identity in the node result event."

View File

@@ -96,7 +96,7 @@ class MockNodeFactory(DifyNodeFactory):
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -106,7 +106,7 @@ class MockNodeFactory(DifyNodeFactory):
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -122,7 +122,7 @@ class MockNodeFactory(DifyNodeFactory):
}:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -132,7 +132,7 @@ class MockNodeFactory(DifyNodeFactory):
else:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,

View File

@@ -56,7 +56,7 @@ class MockNodeMixin:
def __init__(
self,
node_id: str,
config: Any,
data: Any,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
@@ -98,7 +98,7 @@ class MockNodeMixin:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
**kwargs,

View File

@@ -111,7 +111,7 @@ class StaticRepo(HumanInputFormRepository):
def _build_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@@ -140,7 +140,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
node_id=start_config["id"],
config=StartNodeData(title="Start", variables=[]),
data=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -155,7 +155,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
node_id=human_a_config["id"],
config=human_data,
data=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -165,7 +165,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
node_id=human_b_config["id"],
config=human_data,
data=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -183,7 +183,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
node_id=end_config["id"],
config=end_data,
data=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@@ -1,41 +1,36 @@
import time
import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.answer.answer_node import AnswerNode
from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
def test_execute_answer():
def _build_variable_pool() -> VariablePool:
return VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
)
def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode:
graph_config = {
"edges": [
{
"id": "start-source-answer-target",
"source": "start",
"target": "answer",
},
],
"edges": [],
"nodes": [
{"data": {"type": "start", "title": "Start"}, "id": "start"},
{
"data": {
"title": "123",
"title": "Answer",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
"answer": answer,
},
"id": "answer",
},
}
],
}
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
@@ -46,42 +41,31 @@ def test_execute_answer():
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = AnswerNode(
return AnswerNode(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=AnswerNodeData(
title="123",
data=AnswerNodeData(
title="Answer",
type="answer",
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
answer=answer,
),
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
def test_execute_answer_renders_variable_selectors() -> None:
variable_pool = _build_variable_pool()
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
node = _build_answer_node(
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
variable_pool=variable_pool,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
@@ -89,36 +73,11 @@ def test_execute_answer():
def test_execute_answer_renders_structured_output_object_as_json() -> None:
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config={"nodes": [], "edges": []},
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
variable_pool = VariablePool(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool = _build_variable_pool()
variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"})
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = AnswerNode(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=AnswerNodeData(
title="123",
type="answer",
answer="{{#1777539038857.structured_output#}}",
),
node = _build_answer_node(
answer="{{#1777539038857.structured_output#}}",
variable_pool=variable_pool,
)
result = node._run()
@@ -128,35 +87,9 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None:
def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None:
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config={"nodes": [], "edges": []},
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
variable_pool = VariablePool(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = AnswerNode(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=AnswerNodeData(
title="123",
type="answer",
answer="{{#1777539038857.structured_output#}}",
),
node = _build_answer_node(
answer="{{#1777539038857.structured_output#}}",
variable_pool=_build_variable_pool(),
)
result = node._run()

View File

@@ -81,7 +81,7 @@ def test_datasource_node_delegates_to_manager_stream(mocker: MockerFixture):
node = DatasourceNode(
node_id="n",
config=DatasourceNodeData(
data=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",

View File

@@ -29,7 +29,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -85,7 +85,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -143,7 +143,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@@ -230,7 +230,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -320,7 +320,7 @@ def test_init_headers():
node_data=node_data,
timeout=timeout,
http_request_config=HTTP_REQUEST_CONFIG,
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
http_client=ssrf_proxy,
file_manager=file_manager,
)
@@ -357,7 +357,7 @@ def test_init_params():
node_data=node_data,
timeout=timeout,
http_request_config=HTTP_REQUEST_CONFIG,
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
http_client=ssrf_proxy,
file_manager=file_manager,
)
@@ -390,7 +390,7 @@ def test_init_params():
def test_empty_api_key_raises_error_bearer():
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer():
def test_empty_api_key_raises_error_basic():
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic():
def test_empty_api_key_raises_error_custom():
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom():
def test_whitespace_only_api_key_raises_error():
"""Test that whitespace-only API key raises AuthorizationConfigError."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error():
def test_valid_api_key_works():
"""Test that valid API key works correctly for bearer auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@@ -536,7 +536,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable():
# UUID that triggers the json_repair truncation bug
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -583,7 +583,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
"""
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -624,7 +624,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
def test_executor_with_json_body_preserves_numbers_and_strings():
"""Test that numbers are preserved and string values are properly quoted."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)

View File

@@ -110,12 +110,15 @@ def _build_http_node(
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=time.perf_counter(),
)
return HttpRequestNode(
node_id="http-node",
config=HttpRequestNodeData.model_validate(node_data),
data=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@@ -149,7 +149,7 @@ def _build_human_input_node(
)
return HumanInputNode(
node_id=node_id,
config=typed_node_data,
data=typed_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=runtime,
@@ -241,16 +241,16 @@ class TestUserAction:
def test_user_action_length_boundaries(self):
"""Test user action id and title length boundaries."""
action = UserAction(id="a" * 20, title="b" * 20)
action = UserAction(id="a" * 20, title="b" * 100)
assert action.id == "a" * 20
assert action.title == "b" * 20
assert action.title == "b" * 100
@pytest.mark.parametrize(
("field_name", "value"),
[
("id", "a" * 21),
("title", "b" * 21),
("title", "b" * 101),
],
)
def test_user_action_length_limits(self, field_name: str, value: str):
@@ -427,7 +427,7 @@ class TestHumanInputNodeVariableResolution:
"""Tests for resolving variable-based defaults in HumanInputNode."""
def test_resolves_variable_defaults(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@@ -504,7 +504,7 @@ class TestHumanInputNodeVariableResolution:
assert params.resolved_default_values == expected_values
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@@ -565,7 +565,7 @@ class TestHumanInputNodeVariableResolution:
assert not hasattr(pause_event.reason, "form_token")
def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@@ -631,7 +631,7 @@ class TestHumanInputNodeVariableResolution:
assert params.display_in_ui is True
def test_debugger_debug_mode_overrides_email_recipients(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user-123",
app_id="app",
@@ -748,7 +748,7 @@ class TestHumanInputNodeRenderedContent:
"""Tests for rendering submitted content."""
def test_replaces_outputs_placeholders_after_submission(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",

View File

@@ -40,7 +40,7 @@ def _create_human_input_node(
)
return HumanInputNode(
node_id=config["id"],
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=repo,
@@ -51,7 +51,11 @@ def _create_human_input_node(
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
variable_pool=VariablePool.from_bootstrap(
system_variables=system_variables,
user_inputs={},
environment_variables=[],
),
start_at=0.0,
)
graph_init_params = GraphInitParams(
@@ -114,7 +118,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
def _build_timeout_node() -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
variable_pool=VariablePool.from_bootstrap(
system_variables=system_variables,
user_inputs={},
environment_variables=[],
),
start_at=0.0,
)
graph_init_params = GraphInitParams(

View File

@@ -32,7 +32,7 @@ class _MissingGraphBuilder:
def _build_runtime_state() -> GraphRuntimeState:
return GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}),
start_at=0.0,
)
@@ -46,7 +46,7 @@ def _build_iteration_node(
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
node_id="iteration-node",
config=IterationNodeData(
data=IterationNodeData(
type="iteration",
title="Iteration",
iterator_selector=["start", "items"],

View File

@@ -41,7 +41,7 @@ def mock_graph_init_params():
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
@@ -103,7 +103,7 @@ def _build_node(
) -> KnowledgeIndexNode:
return KnowledgeIndexNode(
node_id=node_id,
config=(
data=(
node_data
if isinstance(node_data, KnowledgeIndexNodeData)
else KnowledgeIndexNodeData.model_validate(node_data)

View File

@@ -47,7 +47,7 @@ def mock_graph_init_params():
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
@@ -118,7 +118,7 @@ class TestKnowledgeRetrievalNode:
# Act
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -147,7 +147,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -206,7 +206,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -250,7 +250,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -286,7 +286,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -321,7 +321,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -362,7 +362,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -401,7 +401,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -482,7 +482,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -519,7 +519,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -574,7 +574,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -622,7 +622,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -683,7 +683,7 @@ class TestFetchDatasetRetriever:
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)

View File

@@ -16,10 +16,10 @@ class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
@staticmethod
def _build_node(*, config, graph_init_params, graph_runtime_state):
def _build_node(*, data, graph_init_params, graph_runtime_state):
return ListOperatorNode(
node_id="test",
config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
data=data if isinstance(data, ListOperatorNodeData) else ListOperatorNodeData.model_validate(data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
@@ -65,7 +65,7 @@ class TestListOperatorNode:
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
return self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -83,7 +83,7 @@ class TestListOperatorNode:
}
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -127,7 +127,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -153,7 +153,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -177,7 +177,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -201,7 +201,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -228,7 +228,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -255,7 +255,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -282,7 +282,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -312,7 +312,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -335,7 +335,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = None
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -359,7 +359,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -384,7 +384,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -408,7 +408,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -432,7 +432,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -456,7 +456,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -483,7 +483,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)

View File

@@ -15,7 +15,7 @@ from core.app.llm.model_access import (
)
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
@@ -187,7 +187,7 @@ def graph_init_params() -> GraphInitParams:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -208,7 +208,7 @@ def llm_node(
http_client = mock.MagicMock()
node = LLMNode(
node_id="1",
config=llm_node_data,
data=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -241,9 +241,10 @@ def model_config(monkeypatch: pytest.MonkeyPatch):
)
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test"))
model_assembly = create_plugin_model_assembly(tenant_id="test")
model_provider_factory = model_assembly.model_provider_factory
provider_instance = model_provider_factory.get_model_provider("openai")
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
model_type_instance = model_assembly.create_model_type_instance(provider="openai", model_type=ModelType.LLM)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
@@ -1173,7 +1174,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
http_client = mock.MagicMock()
node = LLMNode(
node_id="1",
config=llm_node_data,
data=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,

View File

@@ -28,7 +28,7 @@ def _build_template_transform_node(
)
return TemplateTransformNode(
node_id=node_id,
config=typed_node_data,
data=typed_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
**kwargs,

View File

@@ -39,7 +39,7 @@ def mock_graph_runtime_state():
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
node = TemplateTransformNode(
node_id="test_node",
config=TemplateTransformNodeData(
data=TemplateTransformNodeData(
title="Template Transform",
type="template-transform",
variables=[],

View File

@@ -35,7 +35,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
invoke_from="debugger",
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=0.0,
)
return init_params, runtime_state
@@ -62,7 +65,7 @@ def test_node_hydrates_data_during_initialization():
node = _SampleNode(
node_id="node-1",
config=_build_node_data(),
data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -82,13 +85,16 @@ def test_node_accepts_invoke_from_enum():
invoke_from=InvokeFrom.DEBUGGER,
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=0.0,
)
node = _SampleNode(
node_id="node-1",
config=_build_node_data(),
data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -140,7 +146,7 @@ def test_node_hydration_preserves_compatibility_extra_fields():
node = _SampleNode(
node_id="node-1",
config=node_config["data"],
data=node_config["data"],
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

View File

@@ -49,7 +49,7 @@ def document_extractor_node(graph_init_params):
http_client = Mock()
node = DocumentExtractorNode(
node_id="test_node_id",
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
http_client=http_client,
@@ -186,12 +186,13 @@ def test_run_extract_text(
monkeypatch.setattr("graphon.file.file_manager.download", mock_download)
dispatch_mock = None
if mime_type == "application/pdf":
mock_pdf_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
dispatch_mock = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", dispatch_mock)
elif mime_type.startswith("application/vnd.openxmlformats"):
mock_docx_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract)
dispatch_mock = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", dispatch_mock)
result = document_extractor_node._run()
@@ -200,6 +201,19 @@ def test_run_extract_text(
assert result.outputs is not None
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
if mime_type == "application/pdf":
dispatch_mock.assert_called_once_with(
file_content=file_content,
file_extension=extension,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
elif mime_type.startswith("application/vnd.openxmlformats"):
dispatch_mock.assert_called_once_with(
file_content=file_content,
mime_type=mime_type,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
if transfer_method == FileTransferMethod.REMOTE_URL:
document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
@@ -439,24 +453,42 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
file.extension = extension
file.mime_type = mime_type
with (
patch(
"graphon.nodes.document_extractor.node._download_file_content",
return_value=b"excel",
),
patch(
"graphon.nodes.document_extractor.node._extract_text_from_excel",
return_value="excel text",
) as mock_extract,
with patch(
"graphon.nodes.document_extractor.node._download_file_content",
return_value=b"excel",
):
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
if extension:
with patch(
"graphon.nodes.document_extractor.node._extract_text_by_file_extension",
return_value="excel text",
) as mock_extract:
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
mock_extract.assert_called_once_with(
file_content=b"excel",
file_extension=extension,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
else:
with patch(
"graphon.nodes.document_extractor.node._extract_text_by_mime_type",
return_value="excel text",
) as mock_extract:
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
mock_extract.assert_called_once_with(
file_content=b"excel",
mime_type=mime_type,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
assert result == "excel text"
mock_extract.assert_called_once_with(b"excel")
def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):

View File

@@ -29,7 +29,7 @@ def _build_if_else_node(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
data=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
)
@@ -48,7 +48,10 @@ def test_execute_if_else_result_true():
)
# construct variable pool
pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={})
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
)
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
@@ -148,7 +151,7 @@ def test_execute_if_else_result_false():
)
# construct variable pool
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@@ -305,7 +308,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
)
# construct variable pool with boolean values
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
@@ -359,7 +362,7 @@ def test_execute_if_else_boolean_false_conditions():
)
# construct variable pool with boolean values
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
@@ -424,7 +427,7 @@ def test_execute_if_else_boolean_cases_structure():
)
# construct variable pool with boolean values
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)

View File

@@ -22,7 +22,7 @@ from graphon.variables import ArrayFileSegment
def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
return ListOperatorNode(
node_id="test_node_id",
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=MagicMock(),
)

View File

@@ -31,7 +31,7 @@ def make_start_node(user_inputs, variables):
return StartNode(
node_id="start",
config=node_data,
data=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
@@ -260,7 +260,7 @@ def test_start_node_outputs_full_variable_pool_snapshot():
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = StartNode(
node_id="start",
config=node_data,
data=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},

View File

@@ -99,7 +99,7 @@ def tool_node(monkeypatch) -> ToolNode:
call_depth=0,
)
variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id"))
variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id"))
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
config = graph_config["nodes"][0]
@@ -110,7 +110,7 @@ def tool_node(monkeypatch) -> ToolNode:
node = ToolNode(
node_id="node-instance",
config=ToolNodeData.model_validate(config["data"]),
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@@ -44,7 +44,7 @@ def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
node_id="node-1",
config=_build_node_data(),
data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

View File

@@ -52,7 +52,7 @@ def create_webhook_node(
node = TriggerWebhookNode(
node_id="webhook-node-1",
config=webhook_data,
data=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@@ -44,7 +44,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
)
node = TriggerWebhookNode(
node_id="1",
config=webhook_data,
data=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, sentinel
@@ -11,19 +12,20 @@ from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.nodes.code.entities import CodeLanguage
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.node import LLMNode
from graphon.variables.segments import StringSegment
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
_ = node_id
if isinstance(config, BaseNodeData):
assert config.type == node_type
assert config.version == version
if isinstance(data, BaseNodeData):
assert data.type == node_type
assert data.version == version
return
assert isinstance(config, dict)
assert config["type"] == node_type
assert config["version"] == version
assert isinstance(data, Mapping)
assert data["type"] == node_type
assert data.get("version", "1") == version
def _node_constructor(*, return_value):
@@ -470,7 +472,7 @@ class TestDifyNodeFactoryCreateNode:
matched_node_class.assert_called_once()
kwargs = matched_node_class.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
latest_node_class.assert_not_called()
@@ -492,7 +494,7 @@ class TestDifyNodeFactoryCreateNode:
latest_node_class.assert_called_once()
kwargs = latest_node_class.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
@@ -530,7 +532,7 @@ class TestDifyNodeFactoryCreateNode:
assert result is created_node
kwargs = constructor.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=node_type)
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
@@ -599,11 +601,12 @@ class TestDifyNodeFactoryCreateNode:
prepared_llm.assert_called_once_with(sentinel.model_instance)
assert kwargs["model_instance"] is wrapped_model_instance
def test_create_node_passes_alias_preserving_llm_config_to_constructor(
self, monkeypatch: pytest.MonkeyPatch, factory
):
def test_create_node_passes_alias_preserving_llm_data_to_constructor(self, monkeypatch, factory):
created_node = object()
constructor = _node_constructor(return_value=created_node)
constructor.validate_node_data.side_effect = lambda node_data: LLMNodeData.model_validate(
node_data.model_dump(mode="python") if isinstance(node_data, BaseNodeData) else node_data
)
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor))
monkeypatch.setattr(factory, "_build_llm_compatible_node_init_kwargs", MagicMock(return_value={}))
@@ -629,10 +632,56 @@ class TestDifyNodeFactoryCreateNode:
factory.create_node(node_config)
config = constructor.call_args.kwargs["config"]
assert isinstance(config, dict)
assert config["structured_output_enabled"] is True
assert "structured_output_switch_on" not in config
data = constructor.call_args.kwargs["data"]
assert isinstance(data, Mapping)
assert data["structured_output_enabled"] is True
assert "structured_output_switch_on" not in data
assert LLMNodeData.model_validate(data).structured_output_enabled is True
def test_create_node_preserves_structured_output_switch_after_graphon_constructor(self, monkeypatch, factory):
factory.graph_init_params = SimpleNamespace(
workflow_id="workflow-id",
graph_config={},
run_context={},
call_depth=0,
)
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=LLMNode))
monkeypatch.setattr(
factory,
"_build_llm_compatible_node_init_kwargs",
MagicMock(
return_value={
"model_instance": sentinel.model_instance,
"llm_file_saver": sentinel.llm_file_saver,
"prompt_message_serializer": sentinel.prompt_message_serializer,
}
),
)
node_config = {
"id": "llm-node-id",
"data": {
"type": BuiltinNodeTypes.LLM,
"title": "LLM",
"model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
"prompt_template": [{"role": "system", "text": "x"}],
"context": {"enabled": False, "variable_selector": []},
"vision": {"enabled": False},
"structured_output_enabled": True,
"structured_output": {
"schema": {
"type": "object",
"properties": {"type": {"type": "string"}},
"required": ["type"],
}
},
},
}
node = factory.create_node(node_config)
assert node.node_data.structured_output_switch_on is True
assert node.node_data.structured_output_enabled is True
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
@@ -711,7 +760,7 @@ class TestDifyNodeFactoryCreateNode:
constructor_kwargs = constructor.call_args.kwargs
assert constructor_kwargs["node_id"] == "node-id"
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
_assert_constructor_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type)
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider

View File

@@ -109,8 +109,8 @@ class TestVariablePool:
assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None
assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None
def test_constructor_loads_legacy_bootstrap_kwargs(self):
pool = VariablePool(
def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self):
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="test_user_id"),
environment_variables=[StringVariable(name="env_var", value="env-value")],
conversation_variables=[StringVariable(name="conv_var", value="conv-value")],

View File

@@ -55,7 +55,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
"""Test mapping system variables from user inputs to variable pool."""
# Initialize variable pool with system variables
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="test_user_id",
app_id="test_app_id",
@@ -128,7 +128,7 @@ class TestWorkflowEntry:
return NodeConfigDictAdapter.validate_python(node_config)
workflow = StubWorkflow()
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
expected_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
@@ -157,7 +157,7 @@ class TestWorkflowEntry:
"""Test mapping environment variables from user inputs to variable pool."""
# Initialize variable pool with environment variables
env_var = StringVariable(name="API_KEY", value="existing_key")
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
environment_variables=[env_var],
user_inputs={},
@@ -198,7 +198,7 @@ class TestWorkflowEntry:
"""Test mapping conversation variables from user inputs to variable pool."""
# Initialize variable pool with conversation variables
conv_var = StringVariable(name="last_message", value="Hello")
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
conversation_variables=[conv_var],
user_inputs={},
@@ -239,7 +239,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
"""Test mapping regular node variables from user inputs to variable pool."""
# Initialize empty variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -281,7 +281,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_file_handling(self):
"""Test mapping file inputs from user inputs to variable pool."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -340,7 +340,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_missing_variable_error(self):
"""Test that mapping raises error when required variable is missing."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -366,7 +366,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_alternative_key_format(self):
"""Test mapping with alternative key format (without node prefix)."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -396,7 +396,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_complex_selectors(self):
"""Test mapping with complex node variable keys."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -432,7 +432,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_invalid_node_variable(self):
"""Test that mapping handles invalid node variable format."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -463,7 +463,7 @@ class TestWorkflowEntry:
env_var = StringVariable(name="API_KEY", value="existing_key")
conv_var = StringVariable(name="session_id", value="session123")
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="test_user",
app_id="test_app",

View File

@@ -7,7 +7,6 @@ import pytest
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.model_manager import ModelInstance
from core.workflow import workflow_entry
from core.workflow.system_variables import default_system_variables
from graphon.entities.base_node_data import BaseNodeData
@@ -16,10 +15,12 @@ from graphon.errors import WorkflowNodeRunFailedError
from graphon.file import File, FileTransferMethod, FileType
from graphon.graph import Graph
from graphon.graph_events import GraphRunFailedEvent
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from graphon.node_events import NodeRunResult
from graphon.nodes import BuiltinNodeTypes
from graphon.nodes.base.node import Node
from graphon.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from graphon.runtime import ChildGraphNotFoundError, VariablePool
from graphon.variables.variables import StringVariable
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -29,9 +30,30 @@ def _build_typed_node_config(node_type: NodeType):
return {"id": "node-id", "data": BaseNodeData(type=node_type)}
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
raw_model_instance = ModelInstance.__new__(ModelInstance)
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
def _build_model_config(*, provider: str = "openai", model_name: str = "gpt-4o") -> ModelConfig:
return ModelConfig(provider=provider, name=model_name, mode=LLMMode.CHAT)
def _build_llm_node_data(*, provider: str = "openai", model_name: str = "gpt-4o") -> LLMNodeData:
return LLMNodeData(
type=BuiltinNodeTypes.LLM,
title="Child Model",
model=_build_model_config(provider=provider, model_name=model_name),
prompt_template=[],
context=ContextConfig(enabled=False),
)
def _build_question_classifier_node_data(
*, provider: str = "openai", model_name: str = "gpt-4o"
) -> QuestionClassifierNodeData:
return QuestionClassifierNodeData(
type=BuiltinNodeTypes.QUESTION_CLASSIFIER,
title="Child Model",
query_variable_selector=["sys", "query"],
model=_build_model_config(provider=provider, model_name=model_name),
classes=[],
)
class _FakeModelNodeMixin:
@@ -40,22 +62,26 @@ class _FakeModelNodeMixin:
return "1"
def post_init(self) -> None:
self.model_instance, self.raw_model_instance = _build_wrapped_model_instance()
self.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
self.usage_snapshot = LLMUsage.empty_usage()
self.usage_snapshot.total_tokens = 1
def _run(self) -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
"model_provider": self.node_data.model.provider,
"model_name": self.node_data.model.name,
},
llm_usage=self.usage_snapshot,
)
class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]):
class _FakeLLMNode(_FakeModelNodeMixin, Node[LLMNodeData]):
node_type = BuiltinNodeTypes.LLM
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]):
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[QuestionClassifierNodeData]):
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
@@ -75,7 +101,7 @@ class TestWorkflowChildEngineBuilder:
assert result is expected
def test_build_child_engine_raises_when_root_node_is_missing(self):
builder = workflow_entry._WorkflowChildEngineBuilder()
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = SimpleNamespace(graph_config={"nodes": []})
parent_graph_runtime_state = SimpleNamespace(
execution_context=sentinel.execution_context,
@@ -92,7 +118,7 @@ class TestWorkflowChildEngineBuilder:
)
def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
builder = workflow_entry._WorkflowChildEngineBuilder()
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]})
parent_graph_runtime_state = SimpleNamespace(
execution_context=sentinel.execution_context,
@@ -114,7 +140,7 @@ class TestWorkflowChildEngineBuilder:
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer),
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer) as llm_quota_layer_cls,
):
result = builder.build_child_engine(
workflow_id="workflow-id",
@@ -147,11 +173,12 @@ class TestWorkflowChildEngineBuilder:
config=sentinel.graph_engine_config,
child_engine_builder=builder,
)
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})]
@pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode])
def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls):
builder = workflow_entry._WorkflowChildEngineBuilder()
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = build_test_graph_init_params(
graph_config={"nodes": [{"id": "root"}], "edges": []},
)
@@ -163,12 +190,10 @@ class TestWorkflowChildEngineBuilder:
def build_graph(*, graph_config, node_factory, root_node_id):
_ = graph_config
node_data = _build_llm_node_data() if node_cls is _FakeLLMNode else _build_question_classifier_node_data()
node = node_cls(
node_id=root_node_id,
config=BaseNodeData(
type=node_cls.node_type,
title="Child Model",
),
data=node_data,
graph_init_params=node_factory.graph_init_params,
graph_runtime_state=node_factory.graph_runtime_state,
)
@@ -191,8 +216,8 @@ class TestWorkflowChildEngineBuilder:
),
),
patch.object(workflow_entry.Graph, "init", side_effect=build_graph),
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota,
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota,
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model") as ensure_quota,
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model") as deduct_quota,
):
child_engine = builder.build_child_engine(
workflow_id="workflow-id",
@@ -203,10 +228,15 @@ class TestWorkflowChildEngineBuilder:
list(child_engine.run())
node = created_node["node"]
ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance)
ensure_quota.assert_called_once_with(
tenant_id="tenant-id",
provider=node.node_data.model.provider,
model=node.node_data.model.name,
)
deduct_quota.assert_called_once_with(
tenant_id="tenant",
model_instance=node.raw_model_instance,
tenant_id="tenant-id",
provider=node.node_data.model.provider,
model=node.node_data.model.name,
usage=node.usage_snapshot,
)
@@ -252,7 +282,7 @@ class TestWorkflowEntryInit:
"ExecutionLimitsLayer",
return_value=execution_limits_layer,
) as execution_limits_layer_cls,
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer) as llm_quota_layer_cls,
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
):
entry = workflow_entry.WorkflowEntry(
@@ -291,6 +321,7 @@ class TestWorkflowEntryInit:
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
assert graph_engine.layer.call_args_list == [
((debug_layer,), {}),
((execution_limits_layer,), {}),
@@ -334,7 +365,7 @@ class TestWorkflowEntrySingleStepRun:
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
variable_loader = MagicMock()
variable_loader.load_variables.return_value = [
StringVariable(

View File

@@ -0,0 +1,130 @@
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import create_engine, select
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from events.event_handlers import update_provider_when_message_created
from models import TenantCreditPool
from models.provider import ProviderType
def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
tenant_id = str(uuid4())
pool_id = str(uuid4())
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": pool_id,
"tenant_id": tenant_id,
"pool_type": ProviderQuotaType.TRIAL,
"quota_limit": 10,
"quota_used": 9,
},
)
system_configuration = SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=10,
)
],
)
application_generate_entity = ChatAppGenerateEntity.model_construct(
app_config=SimpleNamespace(tenant_id=tenant_id),
model_conf=SimpleNamespace(
provider="openai",
model="gpt-4o",
provider_model_bundle=SimpleNamespace(
configuration=SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=system_configuration,
)
),
),
)
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
):
update_provider_when_message_created.handle(
sender=message,
application_generate_entity=application_generate_entity,
)
with engine.connect() as connection:
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
assert quota_used == 10
def test_message_created_paid_credit_accounting_uses_paid_pool() -> None:
tenant_id = str(uuid4())
system_configuration = SimpleNamespace(
current_quota_type=ProviderQuotaType.PAID,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.PAID,
quota_unit=QuotaUnit.TOKENS,
quota_limit=10,
)
],
)
application_generate_entity = ChatAppGenerateEntity.model_construct(
app_config=SimpleNamespace(tenant_id=tenant_id),
model_conf=SimpleNamespace(
provider="openai",
model="gpt-4o",
provider_model_bundle=SimpleNamespace(
configuration=SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=system_configuration,
)
),
),
)
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
with (
patch.object(update_provider_when_message_created, "_deduct_credit_pool_quota_capped") as mock_deduct,
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
):
update_provider_when_message_created.handle(
sender=message,
application_generate_entity=application_generate_entity,
)
mock_deduct.assert_called_once_with(
tenant_id=tenant_id,
credits_required=3,
pool_type="paid",
)
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None:
with patch(
"services.credit_pool_service.CreditPoolService.deduct_credits_capped",
return_value=3,
) as mock_deduct:
update_provider_when_message_created._deduct_credit_pool_quota_capped(
tenant_id="tenant-id",
credits_required=3,
pool_type="trial",
)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
credits_required=3,
pool_type="trial",
)
assert "Credit pool exhausted during message-created accounting" not in caplog.text

View File

@@ -0,0 +1,158 @@
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
tenant_id = str(uuid4())
pool_id = str(uuid4())
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": pool_id,
"tenant_id": tenant_id,
"pool_type": ProviderQuotaType.TRIAL,
"quota_limit": quota_limit,
"quota_used": quota_used,
},
)
return engine, tenant_id, pool_id
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
with engine.connect() as connection:
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 3
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Credit pool not found"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="No credits remaining"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
assert deducted_credits == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert deducted_credits == 0
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 1
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
pytest.raises(QuotaExceededError, match="quota unavailable"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2

View File

@@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
mock_node_cls.assert_called_once_with(
node_id="n-1",
config=sentinel.node_data,
data=sentinel.node_data,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
runtime=mock_runtime_cls.return_value,