mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 15:01:36 -04:00
chore(api): upgrade graphon to v0.3.0 (#35469)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@@ -73,7 +73,7 @@ def test_node_integration_minimal_stream(mocker: MockerFixture):
|
||||
|
||||
node = DatasourceNode(
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
data=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
|
||||
@@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderType
|
||||
|
||||
@@ -15,8 +15,9 @@ def get_mocked_fetch_model_config(
|
||||
mode: str,
|
||||
credentials: dict,
|
||||
):
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
|
||||
model_assembly = create_plugin_model_assembly(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
|
||||
model_provider_factory = model_assembly.model_provider_factory
|
||||
model_type_instance = model_assembly.create_model_type_instance(provider=provider, model_type=ModelType.LLM)
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
configuration=ProviderConfiguration(
|
||||
tenant_id="1",
|
||||
|
||||
@@ -45,7 +45,7 @@ def init_code_node(code_config: dict):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -66,7 +66,7 @@ def init_code_node(code_config: dict):
|
||||
|
||||
node = CodeNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=CodeNodeData.model_validate(code_config["data"]),
|
||||
data=CodeNodeData.model_validate(code_config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
code_executor=node_factory._code_executor,
|
||||
|
||||
@@ -55,7 +55,7 @@ def init_http_node(config: dict):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -76,7 +76,7 @@ def init_http_node(config: dict):
|
||||
|
||||
node = HttpRequestNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(config["data"]),
|
||||
data=HttpRequestNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
@@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
from graphon.runtime import VariablePool
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="test", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
)
|
||||
|
||||
# Create independent variable pool for this test only
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -724,7 +724,7 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
|
||||
node = HttpRequestNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
|
||||
data=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
|
||||
@@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="aaa",
|
||||
app_id=app_id,
|
||||
@@ -77,7 +77,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
|
||||
node = LLMNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=LLMNodeData.model_validate(config["data"]),
|
||||
data=LLMNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
|
||||
),
|
||||
@@ -71,7 +71,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ParameterExtractorNodeData.model_validate(config["data"]),
|
||||
data=ParameterExtractorNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_execute_template_transform():
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -88,7 +88,7 @@ def test_execute_template_transform():
|
||||
|
||||
node = TemplateTransformNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=TemplateTransformNodeData.model_validate(config["data"]),
|
||||
data=TemplateTransformNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
jinja2_template_renderer=_SimpleJinja2Renderer(),
|
||||
|
||||
@@ -43,7 +43,7 @@ def init_tool_node(config: dict):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -64,7 +64,7 @@ def init_tool_node(config: dict):
|
||||
|
||||
node = ToolNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ToolNodeData.model_validate(config["data"]),
|
||||
data=ToolNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
|
||||
@@ -210,7 +210,9 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id))
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id=execution_id)
|
||||
)
|
||||
if variables:
|
||||
for (node_id, var_key), value in variables.items():
|
||||
variable_pool.add([node_id, var_key], value)
|
||||
|
||||
@@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
|
||||
|
||||
|
||||
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
app_id=app_id,
|
||||
@@ -102,7 +102,7 @@ def _build_graph(
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
node_id="start",
|
||||
config=start_data,
|
||||
data=start_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@@ -117,7 +117,7 @@ def _build_graph(
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
node_id="human",
|
||||
config=human_data,
|
||||
data=human_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
@@ -131,7 +131,7 @@ def _build_graph(
|
||||
)
|
||||
end_node = EndNode(
|
||||
node_id="end",
|
||||
config=end_data,
|
||||
data=end_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -90,16 +90,34 @@ class TestCreditPoolService:
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert pool.quota_used == credits_required
|
||||
|
||||
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
|
||||
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
quota_used = pool.quota_used
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == quota_used
|
||||
|
||||
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
quota_limit = pool.quota_limit
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
|
||||
|
||||
assert result == remaining
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == pool.quota_limit
|
||||
assert updated_pool.quota_used == quota_limit
|
||||
|
||||
@@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
pipeline._task_state.answer = "partial answer"
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=build_test_variable_pool(
|
||||
variables=build_system_variables(workflow_execution_id="run-id"),
|
||||
),
|
||||
start_at=0.0,
|
||||
total_tokens=7,
|
||||
node_run_steps=3,
|
||||
@@ -372,7 +374,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
@@ -583,7 +587,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
self.items = items
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@@ -617,7 +623,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch: pytest.MonkeyPatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
|
||||
|
||||
@@ -60,7 +60,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: _StubToolNodeData,
|
||||
data: _StubToolNodeData,
|
||||
*,
|
||||
graph_init_params,
|
||||
graph_runtime_state,
|
||||
@@ -68,7 +68,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
@@ -169,7 +169,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
|
||||
@@ -54,7 +54,7 @@ class TestWorkflowBasedAppRunner:
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestWorkflowBasedAppRunner:
|
||||
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch: pytest.MonkeyPatch):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@@ -164,7 +164,7 @@ class TestWorkflowBasedAppRunner:
|
||||
app_id="app",
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@@ -243,7 +243,7 @@ class TestWorkflowBasedAppRunner:
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_runtime_state.register_paused_node("node-1")
|
||||
@@ -286,7 +286,7 @@ class TestWorkflowBasedAppRunner:
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
@@ -425,7 +425,7 @@ class TestWorkflowBasedAppRunner:
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
@@ -16,7 +16,7 @@ from models.workflow import Workflow
|
||||
|
||||
|
||||
def _make_graph_state():
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
|
||||
@@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=build_test_variable_pool(
|
||||
variables=build_system_variables(workflow_execution_id="run-id"),
|
||||
),
|
||||
start_at=0.0,
|
||||
total_tokens=5,
|
||||
node_run_steps=2,
|
||||
@@ -283,7 +285,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
@@ -725,7 +729,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@@ -753,7 +759,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
|
||||
@@ -769,7 +777,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
def test_process_stream_response_main_match_paths_and_cleanup(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
|
||||
@@ -21,7 +21,9 @@ class TestTriggerPostLayer:
|
||||
)
|
||||
runtime_state = SimpleNamespace(
|
||||
outputs={"answer": "ok"},
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-1")
|
||||
),
|
||||
total_tokens=12,
|
||||
)
|
||||
|
||||
@@ -60,7 +62,9 @@ class TestTriggerPostLayer:
|
||||
def test_on_event_handles_missing_trigger_log(self):
|
||||
runtime_state = SimpleNamespace(
|
||||
outputs={},
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-1")
|
||||
),
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
@@ -91,7 +95,9 @@ class TestTriggerPostLayer:
|
||||
def test_on_event_ignores_non_status_events(self):
|
||||
runtime_state = SimpleNamespace(
|
||||
outputs={},
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-1")
|
||||
),
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
617
api/tests/unit_tests/core/app/test_llm_quota.py
Normal file
617
api/tests/unit_tests/core/app/test_llm_quota.py
Normal file
@@ -0,0 +1,617 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.llm.quota import (
|
||||
deduct_llm_quota,
|
||||
deduct_llm_quota_for_model,
|
||||
ensure_llm_quota_available,
|
||||
ensure_llm_quota_available_for_model,
|
||||
)
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType as ModelProviderQuotaType
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None:
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
|
||||
):
|
||||
ensure_llm_quota_available_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
provider_configuration.get_provider_model.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_for_model_raises_when_provider_is_missing() -> None:
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = None
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
pytest.raises(ValueError, match="Provider openai does not exist."),
|
||||
):
|
||||
ensure_llm_quota_available_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_for_model_ignores_custom_provider_configuration() -> None:
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.CUSTOM,
|
||||
get_provider_model=MagicMock(),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager):
|
||||
ensure_llm_quota_available_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
provider_configuration.get_provider_model.assert_not_called()
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 42
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_deduct_credits.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=42,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 3
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
TenantCreditPool.__table__.insert(),
|
||||
{
|
||||
"id": "trial-pool",
|
||||
"tenant_id": "tenant-id",
|
||||
"pool_type": ModelProviderQuotaType.TRIAL,
|
||||
"quota_limit": 10,
|
||||
"quota_used": 9,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == "trial-pool"))
|
||||
|
||||
assert quota_used == 10
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 42
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=-1,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_deduct_credits.assert_not_called()
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_credit_configuration() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.CREDITS,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch.object(type(dify_config), "get_model_credits", return_value=9) as mock_get_model_credits,
|
||||
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_get_model_credits.assert_called_once_with("gpt-4o")
|
||||
mock_deduct_credits.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=9,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_single_charge_for_times_quota() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TIMES,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_deduct_credits.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=1,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_paid_billing_pool() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 5
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.PAID,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.PAID,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_deduct_credits.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=5,
|
||||
pool_type="paid",
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 3
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.FREE,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.FREE,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Provider.__table__.create(engine)
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
Provider.__table__.insert(),
|
||||
[
|
||||
{
|
||||
"id": "matching-provider",
|
||||
"tenant_id": "tenant-id",
|
||||
"provider_name": "openai",
|
||||
"provider_type": ProviderType.SYSTEM,
|
||||
"quota_type": ProviderQuotaType.FREE,
|
||||
"quota_limit": 100,
|
||||
"quota_used": 10,
|
||||
"is_valid": True,
|
||||
},
|
||||
{
|
||||
"id": "other-tenant",
|
||||
"tenant_id": "other-tenant-id",
|
||||
"provider_name": "openai",
|
||||
"provider_type": ProviderType.SYSTEM,
|
||||
"quota_type": ProviderQuotaType.FREE,
|
||||
"quota_limit": 100,
|
||||
"quota_used": 20,
|
||||
"is_valid": True,
|
||||
},
|
||||
{
|
||||
"id": "other-provider",
|
||||
"tenant_id": "tenant-id",
|
||||
"provider_name": "anthropic",
|
||||
"provider_type": ProviderType.SYSTEM,
|
||||
"quota_type": ProviderQuotaType.FREE,
|
||||
"quota_limit": 100,
|
||||
"quota_used": 30,
|
||||
"is_valid": True,
|
||||
},
|
||||
{
|
||||
"id": "custom-provider",
|
||||
"tenant_id": "tenant-id",
|
||||
"provider_name": "openai",
|
||||
"provider_type": ProviderType.CUSTOM,
|
||||
"quota_type": ProviderQuotaType.FREE,
|
||||
"quota_limit": 100,
|
||||
"quota_used": 40,
|
||||
"is_valid": True,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
quota_used_by_id = dict(connection.execute(select(Provider.id, Provider.quota_used)).all())
|
||||
|
||||
assert quota_used_by_id == {
|
||||
"matching-provider": 13,
|
||||
"other-tenant": 20,
|
||||
"other-provider": 30,
|
||||
"custom-provider": 40,
|
||||
}
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
Provider.__table__.update().where(Provider.id == "matching-provider").values(quota_limit=13, quota_used=13)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
exhausted_quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))
|
||||
|
||||
assert exhausted_quota_used == 13
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 3
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.FREE,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.FREE,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Provider.__table__.create(engine)
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
Provider.__table__.insert(),
|
||||
{
|
||||
"id": "matching-provider",
|
||||
"tenant_id": "tenant-id",
|
||||
"provider_name": "openai",
|
||||
"provider_type": ProviderType.SYSTEM,
|
||||
"quota_type": ProviderQuotaType.FREE,
|
||||
"quota_limit": 15,
|
||||
"quota_used": 13,
|
||||
"is_valid": True,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))
|
||||
|
||||
assert quota_used == 15
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 2
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type="unexpected",
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type="unexpected",
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=100,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
|
||||
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_deduct_credits.assert_not_called()
|
||||
mock_sessionmaker.assert_not_called()
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 2
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.CUSTOM,
|
||||
system_configuration=SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[],
|
||||
),
|
||||
)
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
|
||||
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_deduct_credits.assert_not_called()
|
||||
mock_sessionmaker.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None:
|
||||
model_instance = SimpleNamespace(
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
|
||||
model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
|
||||
)
|
||||
|
||||
with (
|
||||
pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
|
||||
patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure,
|
||||
):
|
||||
ensure_llm_quota_available(model_instance=model_instance)
|
||||
|
||||
mock_ensure.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_wrapper_rejects_non_llm_model_instances() -> None:
|
||||
model_instance = SimpleNamespace(
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
|
||||
model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
|
||||
)
|
||||
|
||||
with (
|
||||
pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
|
||||
pytest.raises(ValueError, match="only support LLM model instances"),
|
||||
):
|
||||
ensure_llm_quota_available(model_instance=model_instance)
|
||||
|
||||
|
||||
def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 7
|
||||
model_instance = SimpleNamespace(
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
|
||||
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
|
||||
)
|
||||
|
||||
with (
|
||||
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
|
||||
patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct,
|
||||
):
|
||||
deduct_llm_quota(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=model_instance,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_llm_quota_wrapper_rejects_non_llm_model_instances() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
model_instance = SimpleNamespace(
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
|
||||
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
|
||||
)
|
||||
|
||||
with (
|
||||
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
|
||||
pytest.raises(ValueError, match="only support LLM model instances"),
|
||||
):
|
||||
deduct_llm_quota(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=model_instance,
|
||||
usage=usage,
|
||||
)
|
||||
@@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes
|
||||
|
||||
|
||||
class DummyNode:
|
||||
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
|
||||
def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs):
|
||||
self.id = node_id
|
||||
self.config = config
|
||||
self.data = data
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.kwargs = kwargs
|
||||
|
||||
@@ -60,7 +60,10 @@ def _make_layer(
|
||||
workflow_execution_id="run-id",
|
||||
conversation_id="conv-id",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=system_variables),
|
||||
start_at=0.0,
|
||||
)
|
||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)
|
||||
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
|
||||
@@ -354,7 +354,8 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
|
||||
with patch(
|
||||
@@ -379,7 +380,10 @@ def test_validate_provider_credentials_without_credential_id() -> None:
|
||||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
|
||||
|
||||
assert validated == {"region": "us"}
|
||||
@@ -426,23 +430,37 @@ def test_switch_preferred_provider_type_creates_record_when_missing() -> None:
|
||||
|
||||
def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
mock_factory = Mock()
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = configuration.provider
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
mock_assembly = Mock()
|
||||
mock_assembly.model_runtime = Mock()
|
||||
mock_assembly.model_provider_factory = mock_factory
|
||||
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
) as mock_factory_builder:
|
||||
with (
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=mock_assembly,
|
||||
) as mock_assembly_builder,
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_model_type_instance",
|
||||
return_value=mock_model_type_instance,
|
||||
) as mock_model_builder,
|
||||
):
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_builder.call_count == 2
|
||||
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
|
||||
assert mock_assembly_builder.call_count == 2
|
||||
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
|
||||
mock_model_builder.assert_called_once_with(
|
||||
runtime=mock_assembly.model_runtime,
|
||||
provider_schema=configuration.provider,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
mock_factory.get_model_schema.assert_called_once_with(
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
@@ -456,17 +474,21 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
bound_runtime = Mock()
|
||||
configuration.bind_model_runtime(bound_runtime)
|
||||
|
||||
mock_factory = Mock()
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = configuration.provider
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
|
||||
) as mock_factory_cls,
|
||||
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
|
||||
patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder,
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_model_type_instance",
|
||||
return_value=mock_model_type_instance,
|
||||
) as mock_model_builder,
|
||||
):
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
@@ -474,8 +496,14 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_cls.call_count == 2
|
||||
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
|
||||
mock_factory_builder.assert_not_called()
|
||||
mock_factory_cls.assert_called_with(runtime=bound_runtime)
|
||||
mock_assembly_builder.assert_not_called()
|
||||
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
|
||||
mock_model_builder.assert_called_once_with(
|
||||
runtime=bound_runtime,
|
||||
provider_schema=configuration.provider,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
|
||||
def test_get_provider_model_returns_none_when_model_not_found() -> None:
|
||||
@@ -504,7 +532,10 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = provider_schema
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
|
||||
active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
|
||||
|
||||
@@ -722,7 +753,8 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
@@ -1069,7 +1101,8 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
@@ -1083,7 +1116,10 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
|
||||
|
||||
mock_factory2 = Mock()
|
||||
mock_factory2.model_credentials_validate.return_value = {"region": "us"}
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory2),
|
||||
):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
@@ -1575,7 +1611,8 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
@@ -1701,7 +1738,8 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
|
||||
@@ -68,8 +68,8 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix
|
||||
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||
|
||||
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
|
||||
|
||||
assert (
|
||||
check_moderation(
|
||||
@@ -91,7 +91,7 @@ def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture)
|
||||
provider_map={openai_provider: hosting_openai},
|
||||
),
|
||||
)
|
||||
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory")
|
||||
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_assembly")
|
||||
choice_mock = mocker.patch("core.helper.moderation.secrets.choice")
|
||||
|
||||
assert (
|
||||
@@ -119,8 +119,8 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi
|
||||
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||
|
||||
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
|
||||
|
||||
assert (
|
||||
check_moderation(
|
||||
@@ -147,8 +147,8 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo
|
||||
failing_model = SimpleNamespace(
|
||||
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: failing_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
|
||||
check_moderation(
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@@ -73,7 +74,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None:
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
|
||||
|
||||
provider_schema = factory.get_model_provider("openai")
|
||||
|
||||
@@ -98,7 +99,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
provider_schema = factory.get_model_provider("openai")
|
||||
|
||||
@@ -107,8 +108,8 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
|
||||
|
||||
|
||||
def test_model_provider_factory_requires_runtime() -> None:
|
||||
with pytest.raises(ValueError, match="model_runtime is required"):
|
||||
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="runtime is required"):
|
||||
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
|
||||
@@ -119,7 +120,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
result = factory.get_providers()
|
||||
|
||||
@@ -133,7 +134,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
|
||||
|
||||
result = factory.get_provider_schema("openai")
|
||||
|
||||
@@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
|
||||
|
||||
def test_model_provider_factory_raises_for_unknown_provider() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
@@ -172,7 +173,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
|
||||
models=[_build_model("rerank-v3", ModelType.RERANK)],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(provider="openai", model_type=ModelType.LLM)
|
||||
|
||||
@@ -196,7 +197,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod
|
||||
models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(model_type=ModelType.TTS)
|
||||
|
||||
@@ -214,7 +215,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider
|
||||
models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)],
|
||||
)
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(provider="openai")
|
||||
|
||||
@@ -242,7 +243,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
|
||||
)
|
||||
]
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
filtered = factory.provider_credentials_validate(
|
||||
provider="openai",
|
||||
@@ -258,7 +259,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
|
||||
|
||||
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
@@ -294,7 +295,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
|
||||
)
|
||||
]
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
filtered = factory.model_credentials_validate(
|
||||
provider="openai",
|
||||
@@ -314,7 +315,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
|
||||
|
||||
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
@@ -346,7 +347,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
|
||||
)
|
||||
runtime.get_model_schema.return_value = "schema"
|
||||
runtime.get_provider_icon.return_value = (b"icon", "image/png")
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
assert (
|
||||
factory.get_model_schema(
|
||||
@@ -382,39 +383,43 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
|
||||
(ModelType.TTS, TTSModel),
|
||||
],
|
||||
)
|
||||
def test_model_provider_factory_builds_model_type_instances(
|
||||
def test_create_model_type_instance_builds_model_wrappers(
|
||||
model_type: ModelType,
|
||||
expected_type: type[object],
|
||||
) -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[model_type],
|
||||
)
|
||||
]
|
||||
)
|
||||
runtime = _FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[model_type],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
instance = factory.get_model_type_instance("openai", model_type)
|
||||
instance = create_model_type_instance(
|
||||
runtime=runtime,
|
||||
provider_schema=runtime.fetch_model_providers()[0],
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
assert isinstance(instance, expected_type)
|
||||
|
||||
|
||||
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
]
|
||||
)
|
||||
def test_create_model_type_instance_rejects_unsupported_model_type() -> None:
|
||||
runtime = _FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
|
||||
factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]
|
||||
create_model_type_instance(
|
||||
runtime=runtime,
|
||||
provider_schema=runtime.fetch_model_providers()[0],
|
||||
model_type="unsupported", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
|
||||
assert assembly.model_manager is model_manager
|
||||
|
||||
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, sentinel
|
||||
from unittest.mock import Mock, patch, sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,6 +13,8 @@ from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
|
||||
@@ -146,7 +148,31 @@ class TestPluginModelRuntime:
|
||||
|
||||
def test_invoke_llm_resolves_plugin_fields(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.invoke_llm.return_value = sentinel.result
|
||||
usage = LLMUsage.empty_usage()
|
||||
client.invoke_llm.return_value = iter(
|
||||
[
|
||||
LLMResultChunk(
|
||||
model="gpt-4o-mini",
|
||||
prompt_messages=[],
|
||||
system_fingerprint="fp-plugin",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content="plugin "),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o-mini",
|
||||
prompt_messages=[],
|
||||
system_fingerprint="fp-plugin",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=1,
|
||||
message=AssistantPromptMessage(content="response"),
|
||||
usage=usage,
|
||||
finish_reason="stop",
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
|
||||
result = runtime.invoke_llm(
|
||||
@@ -160,7 +186,11 @@ class TestPluginModelRuntime:
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.result
|
||||
assert result.model == "gpt-4o-mini"
|
||||
assert result.prompt_messages == []
|
||||
assert result.message.content == "plugin response"
|
||||
assert result.usage == usage
|
||||
assert result.system_fingerprint == "fp-plugin"
|
||||
client.invoke_llm.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
user_id="user",
|
||||
@@ -175,6 +205,38 @@ class TestPluginModelRuntime:
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def test_invoke_llm_returns_plugin_stream_directly(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
stream_result = iter([])
|
||||
client.invoke_llm.return_value = stream_result
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
|
||||
result = runtime.invoke_llm(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0.3},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=("END",),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert result is stream_result
|
||||
client.invoke_llm.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
user_id="user",
|
||||
plugin_id="langgenius/openai",
|
||||
provider="openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0.3},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.invoke_llm.return_value = sentinel.result
|
||||
@@ -267,6 +329,129 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch:
|
||||
client.get_model_schema.assert_not_called()
|
||||
|
||||
|
||||
def test_structured_output_adapter_invokes_bound_runtime_streaming() -> None:
|
||||
runtime = Mock()
|
||||
runtime.invoke_llm.return_value = sentinel.stream_result
|
||||
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
|
||||
runtime=runtime,
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
tool = Mock()
|
||||
|
||||
result = adapter.invoke_llm(
|
||||
prompt_messages=[],
|
||||
model_parameters=None,
|
||||
tools=[tool],
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
callbacks=sentinel.callbacks,
|
||||
)
|
||||
|
||||
assert result is sentinel.stream_result
|
||||
runtime.invoke_llm.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={},
|
||||
prompt_messages=[],
|
||||
tools=[tool],
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None:
|
||||
runtime = Mock()
|
||||
runtime.invoke_llm.return_value = sentinel.result
|
||||
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
|
||||
runtime=runtime,
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
|
||||
result = adapter.invoke_llm(
|
||||
prompt_messages=[],
|
||||
model_parameters={"temperature": 0},
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.result
|
||||
runtime.invoke_llm.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
schema = _build_model_schema()
|
||||
runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign]
|
||||
|
||||
with patch.object(
|
||||
model_runtime_module,
|
||||
"invoke_llm_with_structured_output_helper",
|
||||
return_value=sentinel.structured_result,
|
||||
) as mock_helper:
|
||||
result = runtime.invoke_llm_with_structured_output(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
json_schema={"type": "object"},
|
||||
model_parameters={"temperature": 0},
|
||||
prompt_messages=[],
|
||||
stop=("END",),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.structured_result
|
||||
runtime.get_model_schema.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
helper_kwargs = mock_helper.call_args.kwargs
|
||||
assert helper_kwargs["provider"] == "langgenius/openai/openai"
|
||||
assert helper_kwargs["model_schema"] == schema
|
||||
assert helper_kwargs["json_schema"] == {"type": "object"}
|
||||
assert helper_kwargs["model_parameters"] == {"temperature": 0}
|
||||
assert helper_kwargs["prompt_messages"] == []
|
||||
assert helper_kwargs["tools"] is None
|
||||
assert helper_kwargs["stop"] == ["END"]
|
||||
assert helper_kwargs["stream"] is False
|
||||
assert isinstance(helper_kwargs["model_instance"], model_runtime_module._PluginStructuredOutputModelInstance)
|
||||
|
||||
|
||||
def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign]
|
||||
|
||||
with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"):
|
||||
runtime.invoke_llm_with_structured_output(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
json_schema={"type": "object"},
|
||||
model_parameters={},
|
||||
prompt_messages=[],
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
schema = _build_model_schema()
|
||||
|
||||
@@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
|
||||
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4"
|
||||
assert result.provider.provider == "openai"
|
||||
@@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock
|
||||
result = manager.get_configurations("tenant-id")
|
||||
|
||||
expected_alias = str(ModelProviderID("openai"))
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
assert result.tenant_id == "tenant-id"
|
||||
assert expected_alias in provider_records
|
||||
assert expected_alias in provider_model_records
|
||||
@@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF
|
||||
|
||||
assert first is second
|
||||
mock_get_all_providers.assert_called_once_with("tenant-id")
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
mock_provider_configuration.assert_called_once()
|
||||
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import CommandType
|
||||
from graphon.graph_events import NodeRunSucceededEvent
|
||||
@@ -14,17 +13,7 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
|
||||
def _build_dify_context() -> DifyRunContext:
|
||||
return DifyRunContext(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
|
||||
def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
def _build_succeeded_event(*, provider: str = "openai", model_name: str = "gpt-4o") -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="execution-id",
|
||||
node_id="llm-node-id",
|
||||
@@ -32,113 +21,162 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
start_at=datetime.now(),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"question": "hello"},
|
||||
inputs={
|
||||
"question": "hello",
|
||||
"model_provider": provider,
|
||||
"model_name": model_name,
|
||||
},
|
||||
llm_usage=LLMUsage.empty_usage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
|
||||
raw_model_instance = ModelInstance.__new__(ModelInstance)
|
||||
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
|
||||
def _build_public_model_identity(*, provider: str = "openai", model_name: str = "gpt-4o") -> SimpleNamespace:
|
||||
return SimpleNamespace(provider=provider, name=model_name)
|
||||
|
||||
|
||||
def _build_node_data(*, model: SimpleNamespace | None = None) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
error_strategy=None,
|
||||
retry_config=SimpleNamespace(retry_enabled=False),
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
|
||||
node = MagicMock()
|
||||
node.id = "node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = node_type
|
||||
node.node_data = _build_node_data(model=_build_public_model_identity())
|
||||
node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
|
||||
return node
|
||||
|
||||
|
||||
class _RunnableQuotaNode:
|
||||
id = "node-id"
|
||||
execution_id = "execution-id"
|
||||
node_type = BuiltinNodeTypes.LLM
|
||||
title = "LLM node"
|
||||
|
||||
def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
|
||||
self.node_data = node_data or _build_node_data(model=_build_public_model_identity())
|
||||
self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
|
||||
self.original_run_called = False
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
self.original_run_called = True
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=raw_model_instance,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_question_classifier_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "question-classifier-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
|
||||
result_event = _build_succeeded_event(provider="anthropic", model_name="claude-3-7-sonnet")
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=raw_model_instance,
|
||||
provider="anthropic",
|
||||
model="claude-3-7-sonnet",
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_non_llm_node_is_ignored() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "start-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.START
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node._model_instance = object()
|
||||
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.START)
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_error_is_handled_in_layer() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance = object()
|
||||
def test_precheck_ignores_non_quota_node() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.START)
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
mock_check.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
def test_quota_error_is_handled_in_layer(caplog) -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance, _ = _build_wrapped_model_instance()
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
result_event = _build_succeeded_event()
|
||||
|
||||
with (
|
||||
caplog.at_level(logging.ERROR, logger="core.app.workflow.layers.llm_quota"),
|
||||
patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
) as mock_deduct,
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
assert "LLM quota deduction failed, node_id=node-id" in caplog.text
|
||||
assert not stop_event.is_set()
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_send_abort_command_is_noop_without_channel_or_after_abort() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
|
||||
layer._send_abort_command(reason="no channel")
|
||||
|
||||
layer.command_channel = MagicMock()
|
||||
layer._abort_sent = True
|
||||
layer._send_abort_command(reason="already aborted")
|
||||
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("No credits remaining"),
|
||||
):
|
||||
@@ -152,19 +190,16 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
|
||||
|
||||
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance, _ = _build_wrapped_model_instance()
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
):
|
||||
@@ -177,21 +212,140 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
assert abort_command.reason == "Model provider openai quota exceeded."
|
||||
|
||||
|
||||
def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
def test_quota_precheck_failure_blocks_current_node_run() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
node = _RunnableQuotaNode(stop_event=stop_event)
|
||||
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
):
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
result = node._run()
|
||||
assert not node.original_run_called
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Model provider openai quota exceeded."
|
||||
assert result.error_type == QuotaExceededError.__name__
|
||||
|
||||
|
||||
def test_missing_model_identity_blocks_current_node_run() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _RunnableQuotaNode(stop_event=stop_event, node_data=_build_node_data())
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
result = node._run()
|
||||
assert not node.original_run_called
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "LLM quota check requires public node model identity before execution."
|
||||
assert result.error_type == "LLMQuotaIdentityError"
|
||||
mock_check.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
mock_check.assert_called_once_with(model_instance=raw_model_instance)
|
||||
mock_check.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_precheck_reads_model_identity_from_data_when_node_data_is_absent() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = SimpleNamespace(
|
||||
id="node-id",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
data=_build_node_data(model=_build_public_model_identity(provider="anthropic", model_name="claude")),
|
||||
)
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
mock_check.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="anthropic",
|
||||
model="claude",
|
||||
)
|
||||
|
||||
|
||||
def test_precheck_rejects_invalid_public_model_identity() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.node_data = _build_node_data(model=_build_public_model_identity(provider="", model_name="gpt-4o"))
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert stop_event.is_set()
|
||||
mock_check.assert_not_called()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
|
||||
|
||||
def test_precheck_requires_public_node_model_config() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.node_data = _build_node_data()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert stop_event.is_set()
|
||||
mock_check.assert_not_called()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "LLM quota check requires public node model identity before execution."
|
||||
|
||||
|
||||
def test_deduction_requires_public_event_model_identity() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
result_event = _build_succeeded_event()
|
||||
result_event.node_run_result.inputs = {"question": "hello"}
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
assert stop_event.is_set()
|
||||
mock_deduct.assert_not_called()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "LLM quota deduction requires model identity in the node result event."
|
||||
|
||||
@@ -96,7 +96,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
if node_type == BuiltinNodeTypes.CODE:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@@ -106,7 +106,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@@ -122,7 +122,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
}:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@@ -132,7 +132,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
|
||||
@@ -56,7 +56,7 @@ class MockNodeMixin:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: Any,
|
||||
data: Any,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
@@ -98,7 +98,7 @@ class MockNodeMixin:
|
||||
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
**kwargs,
|
||||
|
||||
@@ -111,7 +111,7 @@ class StaticRepo(HumanInputFormRepository):
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@@ -140,7 +140,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
node_id=start_config["id"],
|
||||
config=StartNodeData(title="Start", variables=[]),
|
||||
data=StartNodeData(title="Start", variables=[]),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@@ -155,7 +155,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
node_id=human_a_config["id"],
|
||||
config=human_data,
|
||||
data=human_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
@@ -165,7 +165,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = HumanInputNode(
|
||||
node_id=human_b_config["id"],
|
||||
config=human_data,
|
||||
data=human_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
@@ -183,7 +183,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
node_id=end_config["id"],
|
||||
config=end_data,
|
||||
data=end_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -1,41 +1,36 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.answer.answer_node import AnswerNode
|
||||
from graphon.nodes.answer.entities import AnswerNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
def _build_variable_pool() -> VariablePool:
|
||||
return VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
|
||||
def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode:
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-answer-target",
|
||||
"source": "start",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"title": "123",
|
||||
"title": "Answer",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
"answer": answer,
|
||||
},
|
||||
"id": "answer",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
@@ -46,42 +41,31 @@ def test_execute_answer():
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = AnswerNode(
|
||||
return AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
data=AnswerNodeData(
|
||||
title="Answer",
|
||||
type="answer",
|
||||
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
answer=answer,
|
||||
),
|
||||
)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
def test_execute_answer_renders_variable_selectors() -> None:
|
||||
variable_pool = _build_variable_pool()
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
node = _build_answer_node(
|
||||
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
@@ -89,36 +73,11 @@ def test_execute_answer():
|
||||
|
||||
|
||||
def test_execute_answer_renders_structured_output_object_as_json() -> None:
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool = _build_variable_pool()
|
||||
variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"})
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
type="answer",
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
),
|
||||
node = _build_answer_node(
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
@@ -128,35 +87,9 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None:
|
||||
|
||||
|
||||
def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None:
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
type="answer",
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
),
|
||||
node = _build_answer_node(
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
variable_pool=_build_variable_pool(),
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
@@ -81,7 +81,7 @@ def test_datasource_node_delegates_to_manager_stream(mocker: MockerFixture):
|
||||
|
||||
node = DatasourceNode(
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
data=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
|
||||
@@ -29,7 +29,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -85,7 +85,7 @@ def test_executor_with_json_body_and_number_variable():
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -143,7 +143,7 @@ def test_executor_with_json_body_and_object_variable():
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable():
|
||||
|
||||
|
||||
def test_extract_selectors_from_template_with_newline():
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
@@ -230,7 +230,7 @@ def test_extract_selectors_from_template_with_newline():
|
||||
|
||||
def test_executor_with_form_data():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -320,7 +320,7 @@ def test_init_headers():
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
@@ -357,7 +357,7 @@ def test_init_params():
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
@@ -390,7 +390,7 @@ def test_init_params():
|
||||
|
||||
def test_empty_api_key_raises_error_bearer():
|
||||
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer():
|
||||
|
||||
def test_empty_api_key_raises_error_basic():
|
||||
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic():
|
||||
|
||||
def test_empty_api_key_raises_error_custom():
|
||||
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom():
|
||||
|
||||
def test_whitespace_only_api_key_raises_error():
|
||||
"""Test that whitespace-only API key raises AuthorizationConfigError."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error():
|
||||
|
||||
def test_valid_api_key_works():
|
||||
"""Test that valid API key works correctly for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@@ -536,7 +536,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable():
|
||||
# UUID that triggers the json_repair truncation bug
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -583,7 +583,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
"""
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -624,7 +624,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
|
||||
def test_executor_with_json_body_preserves_numbers_and_strings():
|
||||
"""Test that numbers are preserved and string values are properly quoted."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@@ -110,12 +110,15 @@ def _build_http_node(
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", files=[]),
|
||||
user_inputs={},
|
||||
),
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
return HttpRequestNode(
|
||||
node_id="http-node",
|
||||
config=HttpRequestNodeData.model_validate(node_data),
|
||||
data=HttpRequestNodeData.model_validate(node_data),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
|
||||
@@ -149,7 +149,7 @@ def _build_human_input_node(
|
||||
)
|
||||
return HumanInputNode(
|
||||
node_id=node_id,
|
||||
config=typed_node_data,
|
||||
data=typed_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=runtime,
|
||||
@@ -241,16 +241,16 @@ class TestUserAction:
|
||||
|
||||
def test_user_action_length_boundaries(self):
|
||||
"""Test user action id and title length boundaries."""
|
||||
action = UserAction(id="a" * 20, title="b" * 20)
|
||||
action = UserAction(id="a" * 20, title="b" * 100)
|
||||
|
||||
assert action.id == "a" * 20
|
||||
assert action.title == "b" * 20
|
||||
assert action.title == "b" * 100
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field_name", "value"),
|
||||
[
|
||||
("id", "a" * 21),
|
||||
("title", "b" * 21),
|
||||
("title", "b" * 101),
|
||||
],
|
||||
)
|
||||
def test_user_action_length_limits(self, field_name: str, value: str):
|
||||
@@ -427,7 +427,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
"""Tests for resolving variable-based defaults in HumanInputNode."""
|
||||
|
||||
def test_resolves_variable_defaults(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@@ -504,7 +504,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
assert params.resolved_default_values == expected_values
|
||||
|
||||
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@@ -565,7 +565,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
assert not hasattr(pause_event.reason, "form_token")
|
||||
|
||||
def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@@ -631,7 +631,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
assert params.display_in_ui is True
|
||||
|
||||
def test_debugger_debug_mode_overrides_email_recipients(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user-123",
|
||||
app_id="app",
|
||||
@@ -748,7 +748,7 @@ class TestHumanInputNodeRenderedContent:
|
||||
"""Tests for rendering submitted content."""
|
||||
|
||||
def test_replaces_outputs_placeholders_after_submission(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
|
||||
@@ -40,7 +40,7 @@ def _create_human_input_node(
|
||||
)
|
||||
return HumanInputNode(
|
||||
node_id=config["id"],
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
@@ -51,7 +51,11 @@ def _create_human_input_node(
|
||||
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
|
||||
system_variables = default_system_variables()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=system_variables,
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -114,7 +118,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
|
||||
def _build_timeout_node() -> HumanInputNode:
|
||||
system_variables = default_system_variables()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=system_variables,
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
@@ -32,7 +32,7 @@ class _MissingGraphBuilder:
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
return GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ def _build_iteration_node(
|
||||
init_params = build_test_graph_init_params(graph_config=graph_config)
|
||||
return IterationNode(
|
||||
node_id="iteration-node",
|
||||
config=IterationNodeData(
|
||||
data=IterationNodeData(
|
||||
type="iteration",
|
||||
title="Iteration",
|
||||
iterator_selector=["start", "items"],
|
||||
|
||||
@@ -41,7 +41,7 @@ def mock_graph_init_params():
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
"""Create mock GraphRuntimeState."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -103,7 +103,7 @@ def _build_node(
|
||||
) -> KnowledgeIndexNode:
|
||||
return KnowledgeIndexNode(
|
||||
node_id=node_id,
|
||||
config=(
|
||||
data=(
|
||||
node_data
|
||||
if isinstance(node_data, KnowledgeIndexNodeData)
|
||||
else KnowledgeIndexNodeData.model_validate(node_data)
|
||||
|
||||
@@ -47,7 +47,7 @@ def mock_graph_init_params():
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
"""Create mock GraphRuntimeState."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -118,7 +118,7 @@ class TestKnowledgeRetrievalNode:
|
||||
# Act
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -147,7 +147,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -206,7 +206,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -250,7 +250,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -286,7 +286,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -321,7 +321,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -362,7 +362,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -401,7 +401,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -482,7 +482,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -519,7 +519,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -574,7 +574,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -622,7 +622,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -683,7 +683,7 @@ class TestFetchDatasetRetriever:
|
||||
config = {"id": node_id, "data": node_data.model_dump()}
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
@@ -16,10 +16,10 @@ class TestListOperatorNode:
|
||||
"""Comprehensive tests for ListOperatorNode."""
|
||||
|
||||
@staticmethod
|
||||
def _build_node(*, config, graph_init_params, graph_runtime_state):
|
||||
def _build_node(*, data, graph_init_params, graph_runtime_state):
|
||||
return ListOperatorNode(
|
||||
node_id="test",
|
||||
config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
|
||||
data=data if isinstance(data, ListOperatorNodeData) else ListOperatorNodeData.model_validate(data),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
@@ -65,7 +65,7 @@ class TestListOperatorNode:
|
||||
def _create_node(config, mock_variable):
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
|
||||
return self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -83,7 +83,7 @@ class TestListOperatorNode:
|
||||
}
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -127,7 +127,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -153,7 +153,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -177,7 +177,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -201,7 +201,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -228,7 +228,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -255,7 +255,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -282,7 +282,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -312,7 +312,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -335,7 +335,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = None
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -359,7 +359,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -384,7 +384,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -408,7 +408,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -432,7 +432,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -456,7 +456,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -483,7 +483,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from core.app.llm.model_access import (
|
||||
)
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.system_variables import default_system_variables
|
||||
from graphon.entities import GraphInitParams
|
||||
@@ -187,7 +187,7 @@ def graph_init_params() -> GraphInitParams:
|
||||
|
||||
@pytest.fixture
|
||||
def graph_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -208,7 +208,7 @@ def llm_node(
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
node_id="1",
|
||||
config=llm_node_data,
|
||||
data=llm_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
@@ -241,9 +241,10 @@ def model_config(monkeypatch: pytest.MonkeyPatch):
|
||||
)
|
||||
|
||||
# Create actual provider and model type instances
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test"))
|
||||
model_assembly = create_plugin_model_assembly(tenant_id="test")
|
||||
model_provider_factory = model_assembly.model_provider_factory
|
||||
provider_instance = model_provider_factory.get_model_provider("openai")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
|
||||
model_type_instance = model_assembly.create_model_type_instance(provider="openai", model_type=ModelType.LLM)
|
||||
|
||||
# Create a ProviderModelBundle
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
@@ -1173,7 +1174,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
node_id="1",
|
||||
config=llm_node_data,
|
||||
data=llm_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
|
||||
@@ -28,7 +28,7 @@ def _build_template_transform_node(
|
||||
)
|
||||
return TemplateTransformNode(
|
||||
node_id=node_id,
|
||||
config=typed_node_data,
|
||||
data=typed_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
**kwargs,
|
||||
|
||||
@@ -39,7 +39,7 @@ def mock_graph_runtime_state():
|
||||
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
|
||||
node = TemplateTransformNode(
|
||||
node_id="test_node",
|
||||
config=TemplateTransformNodeData(
|
||||
data=TemplateTransformNodeData(
|
||||
title="Template Transform",
|
||||
type="template-transform",
|
||||
variables=[],
|
||||
|
||||
@@ -35,7 +35,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
|
||||
invoke_from="debugger",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", files=[]),
|
||||
user_inputs={},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return init_params, runtime_state
|
||||
@@ -62,7 +65,7 @@ def test_node_hydrates_data_during_initialization():
|
||||
|
||||
node = _SampleNode(
|
||||
node_id="node-1",
|
||||
config=_build_node_data(),
|
||||
data=_build_node_data(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@@ -82,13 +85,16 @@ def test_node_accepts_invoke_from_enum():
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", files=[]),
|
||||
user_inputs={},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
node = _SampleNode(
|
||||
node_id="node-1",
|
||||
config=_build_node_data(),
|
||||
data=_build_node_data(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@@ -140,7 +146,7 @@ def test_node_hydration_preserves_compatibility_extra_fields():
|
||||
|
||||
node = _SampleNode(
|
||||
node_id="node-1",
|
||||
config=node_config["data"],
|
||||
data=node_config["data"],
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ def document_extractor_node(graph_init_params):
|
||||
http_client = Mock()
|
||||
node = DocumentExtractorNode(
|
||||
node_id="test_node_id",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=Mock(),
|
||||
http_client=http_client,
|
||||
@@ -186,12 +186,13 @@ def test_run_extract_text(
|
||||
|
||||
monkeypatch.setattr("graphon.file.file_manager.download", mock_download)
|
||||
|
||||
dispatch_mock = None
|
||||
if mime_type == "application/pdf":
|
||||
mock_pdf_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
|
||||
dispatch_mock = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", dispatch_mock)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
mock_docx_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract)
|
||||
dispatch_mock = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", dispatch_mock)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
@@ -200,6 +201,19 @@ def test_run_extract_text(
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
|
||||
|
||||
if mime_type == "application/pdf":
|
||||
dispatch_mock.assert_called_once_with(
|
||||
file_content=file_content,
|
||||
file_extension=extension,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
dispatch_mock.assert_called_once_with(
|
||||
file_content=file_content,
|
||||
mime_type=mime_type,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt")
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
@@ -439,24 +453,42 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
|
||||
file.extension = extension
|
||||
file.mime_type = mime_type
|
||||
|
||||
with (
|
||||
patch(
|
||||
"graphon.nodes.document_extractor.node._download_file_content",
|
||||
return_value=b"excel",
|
||||
),
|
||||
patch(
|
||||
"graphon.nodes.document_extractor.node._extract_text_from_excel",
|
||||
return_value="excel text",
|
||||
) as mock_extract,
|
||||
with patch(
|
||||
"graphon.nodes.document_extractor.node._download_file_content",
|
||||
return_value=b"excel",
|
||||
):
|
||||
result = _extract_text_from_file(
|
||||
document_extractor_node.http_client,
|
||||
file,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
if extension:
|
||||
with patch(
|
||||
"graphon.nodes.document_extractor.node._extract_text_by_file_extension",
|
||||
return_value="excel text",
|
||||
) as mock_extract:
|
||||
result = _extract_text_from_file(
|
||||
document_extractor_node.http_client,
|
||||
file,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
mock_extract.assert_called_once_with(
|
||||
file_content=b"excel",
|
||||
file_extension=extension,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
else:
|
||||
with patch(
|
||||
"graphon.nodes.document_extractor.node._extract_text_by_mime_type",
|
||||
return_value="excel text",
|
||||
) as mock_extract:
|
||||
result = _extract_text_from_file(
|
||||
document_extractor_node.http_client,
|
||||
file,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
mock_extract.assert_called_once_with(
|
||||
file_content=b"excel",
|
||||
mime_type=mime_type,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
|
||||
assert result == "excel text"
|
||||
mock_extract.assert_called_once_with(b"excel")
|
||||
|
||||
|
||||
def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
|
||||
|
||||
@@ -29,7 +29,7 @@ def _build_if_else_node(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
|
||||
data=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
|
||||
)
|
||||
|
||||
|
||||
@@ -48,7 +48,10 @@ def test_execute_if_else_result_true():
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={})
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
)
|
||||
pool.add(["start", "array_contains"], ["ab", "def"])
|
||||
pool.add(["start", "array_not_contains"], ["ac", "def"])
|
||||
pool.add(["start", "contains"], "cabcde")
|
||||
@@ -148,7 +151,7 @@ def test_execute_if_else_result_false():
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@@ -305,7 +308,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
@@ -359,7 +362,7 @@ def test_execute_if_else_boolean_false_conditions():
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
@@ -424,7 +427,7 @@ def test_execute_if_else_boolean_cases_structure():
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
|
||||
@@ -22,7 +22,7 @@ from graphon.variables import ArrayFileSegment
|
||||
def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
|
||||
return ListOperatorNode(
|
||||
node_id="test_node_id",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=MagicMock(),
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ def make_start_node(user_inputs, variables):
|
||||
|
||||
return StartNode(
|
||||
node_id="start",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=build_test_graph_init_params(
|
||||
workflow_id="wf",
|
||||
graph_config={},
|
||||
@@ -260,7 +260,7 @@ def test_start_node_outputs_full_variable_pool_snapshot():
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node = StartNode(
|
||||
node_id="start",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=build_test_graph_init_params(
|
||||
workflow_id="wf",
|
||||
graph_config={},
|
||||
|
||||
@@ -99,7 +99,7 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id"))
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id"))
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
config = graph_config["nodes"][0]
|
||||
@@ -110,7 +110,7 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
|
||||
node = ToolNode(
|
||||
node_id="node-instance",
|
||||
config=ToolNodeData.model_validate(config["data"]),
|
||||
data=ToolNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
|
||||
@@ -44,7 +44,7 @@ def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
|
||||
init_params, runtime_state = _build_context(graph_config={})
|
||||
node = TriggerEventNode(
|
||||
node_id="node-1",
|
||||
config=_build_node_data(),
|
||||
data=_build_node_data(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_webhook_node(
|
||||
|
||||
node = TriggerWebhookNode(
|
||||
node_id="webhook-node-1",
|
||||
config=webhook_data,
|
||||
data=webhook_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
|
||||
)
|
||||
node = TriggerWebhookNode(
|
||||
node_id="1",
|
||||
config=webhook_data,
|
||||
data=webhook_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
@@ -11,19 +12,20 @@ from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
from graphon.variables.segments import StringSegment
|
||||
|
||||
|
||||
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
_ = node_id
|
||||
if isinstance(config, BaseNodeData):
|
||||
assert config.type == node_type
|
||||
assert config.version == version
|
||||
if isinstance(data, BaseNodeData):
|
||||
assert data.type == node_type
|
||||
assert data.version == version
|
||||
return
|
||||
|
||||
assert isinstance(config, dict)
|
||||
assert config["type"] == node_type
|
||||
assert config["version"] == version
|
||||
assert isinstance(data, Mapping)
|
||||
assert data["type"] == node_type
|
||||
assert data.get("version", "1") == version
|
||||
|
||||
|
||||
def _node_constructor(*, return_value):
|
||||
@@ -470,7 +472,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
matched_node_class.assert_called_once()
|
||||
kwargs = matched_node_class.call_args.kwargs
|
||||
assert kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
latest_node_class.assert_not_called()
|
||||
@@ -492,7 +494,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
latest_node_class.assert_called_once()
|
||||
kwargs = latest_node_class.call_args.kwargs
|
||||
assert kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
|
||||
@@ -530,7 +532,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
assert result is created_node
|
||||
kwargs = constructor.call_args.kwargs
|
||||
assert kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=node_type)
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
|
||||
@@ -599,11 +601,12 @@ class TestDifyNodeFactoryCreateNode:
|
||||
prepared_llm.assert_called_once_with(sentinel.model_instance)
|
||||
assert kwargs["model_instance"] is wrapped_model_instance
|
||||
|
||||
def test_create_node_passes_alias_preserving_llm_config_to_constructor(
|
||||
self, monkeypatch: pytest.MonkeyPatch, factory
|
||||
):
|
||||
def test_create_node_passes_alias_preserving_llm_data_to_constructor(self, monkeypatch, factory):
|
||||
created_node = object()
|
||||
constructor = _node_constructor(return_value=created_node)
|
||||
constructor.validate_node_data.side_effect = lambda node_data: LLMNodeData.model_validate(
|
||||
node_data.model_dump(mode="python") if isinstance(node_data, BaseNodeData) else node_data
|
||||
)
|
||||
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor))
|
||||
monkeypatch.setattr(factory, "_build_llm_compatible_node_init_kwargs", MagicMock(return_value={}))
|
||||
|
||||
@@ -629,10 +632,56 @@ class TestDifyNodeFactoryCreateNode:
|
||||
|
||||
factory.create_node(node_config)
|
||||
|
||||
config = constructor.call_args.kwargs["config"]
|
||||
assert isinstance(config, dict)
|
||||
assert config["structured_output_enabled"] is True
|
||||
assert "structured_output_switch_on" not in config
|
||||
data = constructor.call_args.kwargs["data"]
|
||||
assert isinstance(data, Mapping)
|
||||
assert data["structured_output_enabled"] is True
|
||||
assert "structured_output_switch_on" not in data
|
||||
assert LLMNodeData.model_validate(data).structured_output_enabled is True
|
||||
|
||||
def test_create_node_preserves_structured_output_switch_after_graphon_constructor(self, monkeypatch, factory):
|
||||
factory.graph_init_params = SimpleNamespace(
|
||||
workflow_id="workflow-id",
|
||||
graph_config={},
|
||||
run_context={},
|
||||
call_depth=0,
|
||||
)
|
||||
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=LLMNode))
|
||||
monkeypatch.setattr(
|
||||
factory,
|
||||
"_build_llm_compatible_node_init_kwargs",
|
||||
MagicMock(
|
||||
return_value={
|
||||
"model_instance": sentinel.model_instance,
|
||||
"llm_file_saver": sentinel.llm_file_saver,
|
||||
"prompt_message_serializer": sentinel.prompt_message_serializer,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "llm-node-id",
|
||||
"data": {
|
||||
"type": BuiltinNodeTypes.LLM,
|
||||
"title": "LLM",
|
||||
"model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
|
||||
"prompt_template": [{"role": "system", "text": "x"}],
|
||||
"context": {"enabled": False, "variable_selector": []},
|
||||
"vision": {"enabled": False},
|
||||
"structured_output_enabled": True,
|
||||
"structured_output": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"type": {"type": "string"}},
|
||||
"required": ["type"],
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node = factory.create_node(node_config)
|
||||
|
||||
assert node.node_data.structured_output_switch_on is True
|
||||
assert node.node_data.structured_output_enabled is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
@@ -711,7 +760,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
|
||||
constructor_kwargs = constructor.call_args.kwargs
|
||||
assert constructor_kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
_assert_constructor_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type)
|
||||
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
|
||||
|
||||
@@ -109,8 +109,8 @@ class TestVariablePool:
|
||||
assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None
|
||||
assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None
|
||||
|
||||
def test_constructor_loads_legacy_bootstrap_kwargs(self):
|
||||
pool = VariablePool(
|
||||
def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self):
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="test_user_id"),
|
||||
environment_variables=[StringVariable(name="env_var", value="env-value")],
|
||||
conversation_variables=[StringVariable(name="conv_var", value="conv-value")],
|
||||
|
||||
@@ -55,7 +55,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
|
||||
"""Test mapping system variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with system variables
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="test_user_id",
|
||||
app_id="test_app_id",
|
||||
@@ -128,7 +128,7 @@ class TestWorkflowEntry:
|
||||
return NodeConfigDictAdapter.validate_python(node_config)
|
||||
|
||||
workflow = StubWorkflow()
|
||||
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
|
||||
expected_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
@@ -157,7 +157,7 @@ class TestWorkflowEntry:
|
||||
"""Test mapping environment variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with environment variables
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
environment_variables=[env_var],
|
||||
user_inputs={},
|
||||
@@ -198,7 +198,7 @@ class TestWorkflowEntry:
|
||||
"""Test mapping conversation variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with conversation variables
|
||||
conv_var = StringVariable(name="last_message", value="Hello")
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
conversation_variables=[conv_var],
|
||||
user_inputs={},
|
||||
@@ -239,7 +239,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
|
||||
"""Test mapping regular node variables from user inputs to variable pool."""
|
||||
# Initialize empty variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -281,7 +281,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_with_file_handling(self):
|
||||
"""Test mapping file inputs from user inputs to variable pool."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -340,7 +340,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_missing_variable_error(self):
|
||||
"""Test that mapping raises error when required variable is missing."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -366,7 +366,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_with_alternative_key_format(self):
|
||||
"""Test mapping with alternative key format (without node prefix)."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -396,7 +396,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_with_complex_selectors(self):
|
||||
"""Test mapping with complex node variable keys."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -432,7 +432,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_invalid_node_variable(self):
|
||||
"""Test that mapping handles invalid node variable format."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@@ -463,7 +463,7 @@ class TestWorkflowEntry:
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
conv_var = StringVariable(name="session_id", value="session123")
|
||||
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="test_user",
|
||||
app_id="test_app",
|
||||
|
||||
@@ -7,7 +7,6 @@ import pytest
|
||||
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow import workflow_entry
|
||||
from core.workflow.system_variables import default_system_variables
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
@@ -16,10 +15,12 @@ from graphon.errors import WorkflowNodeRunFailedError
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphRunFailedEvent
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from graphon.runtime import ChildGraphNotFoundError, VariablePool
|
||||
from graphon.variables.variables import StringVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
|
||||
@@ -29,9 +30,30 @@ def _build_typed_node_config(node_type: NodeType):
|
||||
return {"id": "node-id", "data": BaseNodeData(type=node_type)}
|
||||
|
||||
|
||||
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
|
||||
raw_model_instance = ModelInstance.__new__(ModelInstance)
|
||||
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
|
||||
def _build_model_config(*, provider: str = "openai", model_name: str = "gpt-4o") -> ModelConfig:
|
||||
return ModelConfig(provider=provider, name=model_name, mode=LLMMode.CHAT)
|
||||
|
||||
|
||||
def _build_llm_node_data(*, provider: str = "openai", model_name: str = "gpt-4o") -> LLMNodeData:
|
||||
return LLMNodeData(
|
||||
type=BuiltinNodeTypes.LLM,
|
||||
title="Child Model",
|
||||
model=_build_model_config(provider=provider, model_name=model_name),
|
||||
prompt_template=[],
|
||||
context=ContextConfig(enabled=False),
|
||||
)
|
||||
|
||||
|
||||
def _build_question_classifier_node_data(
|
||||
*, provider: str = "openai", model_name: str = "gpt-4o"
|
||||
) -> QuestionClassifierNodeData:
|
||||
return QuestionClassifierNodeData(
|
||||
type=BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
title="Child Model",
|
||||
query_variable_selector=["sys", "query"],
|
||||
model=_build_model_config(provider=provider, model_name=model_name),
|
||||
classes=[],
|
||||
)
|
||||
|
||||
|
||||
class _FakeModelNodeMixin:
|
||||
@@ -40,22 +62,26 @@ class _FakeModelNodeMixin:
|
||||
return "1"
|
||||
|
||||
def post_init(self) -> None:
|
||||
self.model_instance, self.raw_model_instance = _build_wrapped_model_instance()
|
||||
self.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
|
||||
self.usage_snapshot = LLMUsage.empty_usage()
|
||||
self.usage_snapshot.total_tokens = 1
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
"model_provider": self.node_data.model.provider,
|
||||
"model_name": self.node_data.model.name,
|
||||
},
|
||||
llm_usage=self.usage_snapshot,
|
||||
)
|
||||
|
||||
|
||||
class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]):
|
||||
class _FakeLLMNode(_FakeModelNodeMixin, Node[LLMNodeData]):
|
||||
node_type = BuiltinNodeTypes.LLM
|
||||
|
||||
|
||||
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]):
|
||||
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[QuestionClassifierNodeData]):
|
||||
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
|
||||
|
||||
|
||||
@@ -75,7 +101,7 @@ class TestWorkflowChildEngineBuilder:
|
||||
assert result is expected
|
||||
|
||||
def test_build_child_engine_raises_when_root_node_is_missing(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
|
||||
graph_init_params = SimpleNamespace(graph_config={"nodes": []})
|
||||
parent_graph_runtime_state = SimpleNamespace(
|
||||
execution_context=sentinel.execution_context,
|
||||
@@ -92,7 +118,7 @@ class TestWorkflowChildEngineBuilder:
|
||||
)
|
||||
|
||||
def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
|
||||
graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]})
|
||||
parent_graph_runtime_state = SimpleNamespace(
|
||||
execution_context=sentinel.execution_context,
|
||||
@@ -114,7 +140,7 @@ class TestWorkflowChildEngineBuilder:
|
||||
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
|
||||
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
|
||||
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer) as llm_quota_layer_cls,
|
||||
):
|
||||
result = builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
@@ -147,11 +173,12 @@ class TestWorkflowChildEngineBuilder:
|
||||
config=sentinel.graph_engine_config,
|
||||
child_engine_builder=builder,
|
||||
)
|
||||
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
|
||||
assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})]
|
||||
|
||||
@pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode])
|
||||
def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
|
||||
graph_init_params = build_test_graph_init_params(
|
||||
graph_config={"nodes": [{"id": "root"}], "edges": []},
|
||||
)
|
||||
@@ -163,12 +190,10 @@ class TestWorkflowChildEngineBuilder:
|
||||
|
||||
def build_graph(*, graph_config, node_factory, root_node_id):
|
||||
_ = graph_config
|
||||
node_data = _build_llm_node_data() if node_cls is _FakeLLMNode else _build_question_classifier_node_data()
|
||||
node = node_cls(
|
||||
node_id=root_node_id,
|
||||
config=BaseNodeData(
|
||||
type=node_cls.node_type,
|
||||
title="Child Model",
|
||||
),
|
||||
data=node_data,
|
||||
graph_init_params=node_factory.graph_init_params,
|
||||
graph_runtime_state=node_factory.graph_runtime_state,
|
||||
)
|
||||
@@ -191,8 +216,8 @@ class TestWorkflowChildEngineBuilder:
|
||||
),
|
||||
),
|
||||
patch.object(workflow_entry.Graph, "init", side_effect=build_graph),
|
||||
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota,
|
||||
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota,
|
||||
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model") as ensure_quota,
|
||||
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model") as deduct_quota,
|
||||
):
|
||||
child_engine = builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
@@ -203,10 +228,15 @@ class TestWorkflowChildEngineBuilder:
|
||||
list(child_engine.run())
|
||||
|
||||
node = created_node["node"]
|
||||
ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance)
|
||||
ensure_quota.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider=node.node_data.model.provider,
|
||||
model=node.node_data.model.name,
|
||||
)
|
||||
deduct_quota.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
model_instance=node.raw_model_instance,
|
||||
tenant_id="tenant-id",
|
||||
provider=node.node_data.model.provider,
|
||||
model=node.node_data.model.name,
|
||||
usage=node.usage_snapshot,
|
||||
)
|
||||
|
||||
@@ -252,7 +282,7 @@ class TestWorkflowEntryInit:
|
||||
"ExecutionLimitsLayer",
|
||||
return_value=execution_limits_layer,
|
||||
) as execution_limits_layer_cls,
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer) as llm_quota_layer_cls,
|
||||
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
|
||||
):
|
||||
entry = workflow_entry.WorkflowEntry(
|
||||
@@ -291,6 +321,7 @@ class TestWorkflowEntryInit:
|
||||
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
)
|
||||
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
|
||||
assert graph_engine.layer.call_args_list == [
|
||||
((debug_layer,), {}),
|
||||
((execution_limits_layer,), {}),
|
||||
@@ -334,7 +365,7 @@ class TestWorkflowEntrySingleStepRun:
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
|
||||
variable_loader = MagicMock()
|
||||
variable_loader.load_variables.return_value = [
|
||||
StringVariable(
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from events.event_handlers import update_provider_when_message_created
|
||||
from models import TenantCreditPool
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
tenant_id = str(uuid4())
|
||||
pool_id = str(uuid4())
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
TenantCreditPool.__table__.insert(),
|
||||
{
|
||||
"id": pool_id,
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": ProviderQuotaType.TRIAL,
|
||||
"quota_limit": 10,
|
||||
"quota_used": 9,
|
||||
},
|
||||
)
|
||||
|
||||
system_configuration = SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=10,
|
||||
)
|
||||
],
|
||||
)
|
||||
application_generate_entity = ChatAppGenerateEntity.model_construct(
|
||||
app_config=SimpleNamespace(tenant_id=tenant_id),
|
||||
model_conf=SimpleNamespace(
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
provider_model_bundle=SimpleNamespace(
|
||||
configuration=SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=system_configuration,
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
|
||||
):
|
||||
update_provider_when_message_created.handle(
|
||||
sender=message,
|
||||
application_generate_entity=application_generate_entity,
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
|
||||
assert quota_used == 10
|
||||
|
||||
|
||||
def test_message_created_paid_credit_accounting_uses_paid_pool() -> None:
|
||||
tenant_id = str(uuid4())
|
||||
system_configuration = SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.PAID,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.PAID,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=10,
|
||||
)
|
||||
],
|
||||
)
|
||||
application_generate_entity = ChatAppGenerateEntity.model_construct(
|
||||
app_config=SimpleNamespace(tenant_id=tenant_id),
|
||||
model_conf=SimpleNamespace(
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
provider_model_bundle=SimpleNamespace(
|
||||
configuration=SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=system_configuration,
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
|
||||
|
||||
with (
|
||||
patch.object(update_provider_when_message_created, "_deduct_credit_pool_quota_capped") as mock_deduct,
|
||||
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
|
||||
):
|
||||
update_provider_when_message_created.handle(
|
||||
sender=message,
|
||||
application_generate_entity=application_generate_entity,
|
||||
)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=3,
|
||||
pool_type="paid",
|
||||
)
|
||||
|
||||
|
||||
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None:
|
||||
with patch(
|
||||
"services.credit_pool_service.CreditPoolService.deduct_credits_capped",
|
||||
return_value=3,
|
||||
) as mock_deduct:
|
||||
update_provider_when_message_created._deduct_credit_pool_quota_capped(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=3,
|
||||
pool_type="trial",
|
||||
)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=3,
|
||||
pool_type="trial",
|
||||
)
|
||||
assert "Credit pool exhausted during message-created accounting" not in caplog.text
|
||||
158
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
158
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
tenant_id = str(uuid4())
|
||||
pool_id = str(uuid4())
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
TenantCreditPool.__table__.insert(),
|
||||
{
|
||||
"id": pool_id,
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": ProviderQuotaType.TRIAL,
|
||||
"quota_limit": quota_limit,
|
||||
"quota_used": quota_used,
|
||||
},
|
||||
)
|
||||
return engine, tenant_id, pool_id
|
||||
|
||||
|
||||
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
|
||||
with engine.connect() as connection:
|
||||
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert deducted_credits == 3
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
|
||||
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="Credit pool not found"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="No credits remaining"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
|
||||
assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert deducted_credits == 1
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
|
||||
|
||||
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="quota unavailable"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
@@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
|
||||
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
|
||||
mock_node_cls.assert_called_once_with(
|
||||
node_id="n-1",
|
||||
config=sentinel.node_data,
|
||||
data=sentinel.node_data,
|
||||
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
|
||||
graph_runtime_state=ANY,
|
||||
runtime=mock_runtime_cls.return_value,
|
||||
|
||||
Reference in New Issue
Block a user