mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 19:22:10 -05:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@@ -16,11 +16,9 @@ if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_app_module():
|
||||
@pytest.fixture(scope="module")
|
||||
def app_module():
|
||||
module_name = "controllers.console.app.app"
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
|
||||
root = Path(__file__).resolve().parents[5]
|
||||
module_path = root / "controllers" / "console" / "app" / "app.py"
|
||||
|
||||
@@ -59,8 +57,12 @@ def _load_app_module():
|
||||
|
||||
stub_namespace = _StubNamespace()
|
||||
|
||||
original_console = sys.modules.get("controllers.console")
|
||||
original_app_pkg = sys.modules.get("controllers.console.app")
|
||||
original_modules: dict[str, ModuleType | None] = {
|
||||
"controllers.console": sys.modules.get("controllers.console"),
|
||||
"controllers.console.app": sys.modules.get("controllers.console.app"),
|
||||
"controllers.common.schema": sys.modules.get("controllers.common.schema"),
|
||||
module_name: sys.modules.get(module_name),
|
||||
}
|
||||
stubbed_modules: list[tuple[str, ModuleType | None]] = []
|
||||
|
||||
console_module = ModuleType("controllers.console")
|
||||
@@ -105,35 +107,35 @@ def _load_app_module():
|
||||
module = util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
try:
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
yield module
|
||||
finally:
|
||||
for name, original in reversed(stubbed_modules):
|
||||
if original is not None:
|
||||
sys.modules[name] = original
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
if original_console is not None:
|
||||
sys.modules["controllers.console"] = original_console
|
||||
else:
|
||||
sys.modules.pop("controllers.console", None)
|
||||
if original_app_pkg is not None:
|
||||
sys.modules["controllers.console.app"] = original_app_pkg
|
||||
else:
|
||||
sys.modules.pop("controllers.console.app", None)
|
||||
|
||||
return module
|
||||
for name, original in original_modules.items():
|
||||
if original is not None:
|
||||
sys.modules[name] = original
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
_app_module = _load_app_module()
|
||||
AppDetailWithSite = _app_module.AppDetailWithSite
|
||||
AppPagination = _app_module.AppPagination
|
||||
AppPartial = _app_module.AppPartial
|
||||
@pytest.fixture(scope="module")
|
||||
def app_models(app_module):
|
||||
return SimpleNamespace(
|
||||
AppDetailWithSite=app_module.AppDetailWithSite,
|
||||
AppPagination=app_module.AppPagination,
|
||||
AppPartial=app_module.AppPartial,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_signed_url(monkeypatch):
|
||||
def patch_signed_url(monkeypatch, app_module):
|
||||
"""Ensure icon URL generation uses a deterministic helper for tests."""
|
||||
|
||||
def _fake_signed_url(key: str | None) -> str | None:
|
||||
@@ -141,7 +143,7 @@ def patch_signed_url(monkeypatch):
|
||||
return None
|
||||
return f"signed:{key}"
|
||||
|
||||
monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
|
||||
monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
|
||||
|
||||
|
||||
def _ts(hour: int = 12) -> datetime:
|
||||
@@ -169,7 +171,8 @@ def _dummy_workflow():
|
||||
)
|
||||
|
||||
|
||||
def test_app_partial_serialization_uses_aliases():
|
||||
def test_app_partial_serialization_uses_aliases(app_models):
|
||||
AppPartial = app_models.AppPartial
|
||||
created_at = _ts()
|
||||
app_obj = SimpleNamespace(
|
||||
id="app-1",
|
||||
@@ -204,7 +207,8 @@ def test_app_partial_serialization_uses_aliases():
|
||||
assert serialized["tags"][0]["name"] == "Utilities"
|
||||
|
||||
|
||||
def test_app_detail_with_site_includes_nested_serialization():
|
||||
def test_app_detail_with_site_includes_nested_serialization(app_models):
|
||||
AppDetailWithSite = app_models.AppDetailWithSite
|
||||
timestamp = _ts(14)
|
||||
site = SimpleNamespace(
|
||||
code="site-code",
|
||||
@@ -253,7 +257,8 @@ def test_app_detail_with_site_includes_nested_serialization():
|
||||
assert serialized["site"]["created_at"] == int(timestamp.timestamp())
|
||||
|
||||
|
||||
def test_app_pagination_aliases_per_page_and_has_next():
|
||||
def test_app_pagination_aliases_per_page_and_has_next(app_models):
|
||||
AppPagination = app_models.AppPagination
|
||||
item_one = SimpleNamespace(
|
||||
id="app-10",
|
||||
name="Paginated One",
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console import wraps as console_wraps
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app import wraps as app_wraps
|
||||
from libs import login as login_lib
|
||||
from models.account import Account, AccountStatus, TenantAccountRole
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.role = TenantAccountRole.OWNER
|
||||
account.id = "account-123" # type: ignore[assignment]
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined]
|
||||
account._get_current_object = lambda: account # type: ignore[attr-defined]
|
||||
return account
|
||||
|
||||
|
||||
def _make_app(mode: AppMode) -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value)
|
||||
|
||||
|
||||
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
|
||||
# Skip setup and auth guardrails
|
||||
monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD")
|
||||
monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
|
||||
monkeypatch.setattr(login_lib, "current_user", account)
|
||||
monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
# Avoid hitting the database when resolving the app model
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreviewCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
PreviewCase(
|
||||
resource_cls=workflow_module.AdvancedChatDraftHumanInputFormPreviewApi,
|
||||
path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form/preview",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
PreviewCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputFormPreviewApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form/preview",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_preview_delegates_to_service(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase
|
||||
) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
preview_payload = {
|
||||
"form_id": "node-42",
|
||||
"form_content": "<div>example</div>",
|
||||
"inputs": [{"name": "topic"}],
|
||||
"actions": [{"id": "continue"}],
|
||||
}
|
||||
service_instance = MagicMock()
|
||||
service_instance.get_human_input_form_preview.return_value = preview_payload
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-42")
|
||||
|
||||
assert response == preview_payload
|
||||
service_instance.get_human_input_form_preview.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-42",
|
||||
inputs={"topic": "tech"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubmitCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
SubmitCase(
|
||||
resource_cls=workflow_module.AdvancedChatDraftHumanInputFormRunApi,
|
||||
path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form/run",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
SubmitCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputFormRunApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form/run",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "<p>done</p>"}, "action": "approve"}
|
||||
service_instance = MagicMock()
|
||||
service_instance.submit_human_input_form_preview.return_value = result_payload
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(
|
||||
case.path,
|
||||
method="POST",
|
||||
json={"form_inputs": {"answer": "42"}, "inputs": {"#node-1.result#": "LLM output"}, "action": "approve"},
|
||||
):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-99")
|
||||
|
||||
assert response == result_payload
|
||||
service_instance.submit_human_input_form_preview.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-99",
|
||||
form_inputs={"answer": "42"},
|
||||
inputs={"#node-1.result#": "LLM output"},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeliveryTestCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
DeliveryTestCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
DeliveryTestCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_delivery_test_calls_service(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, case: DeliveryTestCase
|
||||
) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
service_instance = MagicMock()
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(
|
||||
case.path,
|
||||
method="POST",
|
||||
json={"delivery_method_id": "delivery-123"},
|
||||
):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-7")
|
||||
|
||||
assert response == {}
|
||||
service_instance.test_human_input_delivery.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-7",
|
||||
delivery_method_id="delivery-123",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_delivery_test_maps_validation_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(AppMode.ADVANCED_CHAT)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.test_human_input_delivery.side_effect = ValueError("bad delivery method")
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-123/workflows/draft/human-input/nodes/node-1/delivery-test",
|
||||
method="POST",
|
||||
json={"delivery_method_id": "bad"},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
workflow_module.WorkflowDraftHumanInputDeliveryTestApi().post(app_id=app_model.id, node_id="node-1")
|
||||
|
||||
|
||||
def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(AppMode.ADVANCED_CHAT)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form/preview",
|
||||
method="POST",
|
||||
json={"inputs": ["not-a-dict"]},
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
workflow_module.AdvancedChatDraftHumanInputFormPreviewApi().post(app_id=app_model.id, node_id="node-1")
|
||||
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import wraps as console_wraps
|
||||
from controllers.console.app import workflow_run as workflow_run_module
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from libs import login as login_lib
|
||||
from models.account import Account, AccountStatus, TenantAccountRole
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.role = TenantAccountRole.OWNER
|
||||
account.id = "account-123" # type: ignore[assignment]
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined]
|
||||
account._get_current_object = lambda: account # type: ignore[attr-defined]
|
||||
return account
|
||||
|
||||
|
||||
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None:
|
||||
monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
|
||||
monkeypatch.setattr(login_lib, "current_user", account)
|
||||
monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(workflow_run_module, "current_user", account)
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
|
||||
class _PauseEntity:
|
||||
def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]):
|
||||
self.paused_at = paused_at
|
||||
self._reasons = reasons
|
||||
|
||||
def get_pause_reasons(self):
|
||||
return self._reasons
|
||||
|
||||
|
||||
def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.tenant_id = "tenant-123"
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
|
||||
monkeypatch.setattr(workflow_run_module, "db", fake_db)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
node_id="node-1",
|
||||
node_title="Ask Name",
|
||||
form_token="backstage-token",
|
||||
)
|
||||
pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason])
|
||||
|
||||
repo = Mock()
|
||||
repo.get_workflow_pause.return_value = pause_entity
|
||||
monkeypatch.setattr(
|
||||
workflow_run_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_, **__: repo,
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["paused_at"] == "2024-01-01T12:00:00Z"
|
||||
assert response["paused_nodes"][0]["node_id"] == "node-1"
|
||||
assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input"
|
||||
assert (
|
||||
response["paused_nodes"][0]["pause_type"]["backstage_input_url"]
|
||||
== "https://web.example.com/form/backstage-token"
|
||||
)
|
||||
assert "pending_human_inputs" not in response
|
||||
|
||||
|
||||
def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.tenant_id = "tenant-456"
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
|
||||
monkeypatch.setattr(workflow_run_module, "db", fake_db)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
@@ -0,0 +1,25 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
||||
|
||||
def test_workflow_run_status_field_with_enum() -> None:
|
||||
field = WorkflowRunStatusField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED)
|
||||
|
||||
assert field.output("status", obj) == "paused"
|
||||
|
||||
|
||||
def test_workflow_run_outputs_field_paused_returns_empty() -> None:
|
||||
field = WorkflowRunOutputsField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"})
|
||||
|
||||
assert field.output("outputs", obj) == {}
|
||||
|
||||
|
||||
def test_workflow_run_outputs_field_running_returns_outputs() -> None:
|
||||
field = WorkflowRunOutputsField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"})
|
||||
|
||||
assert field.output("outputs", obj) == {"foo": "bar"}
|
||||
456
api/tests/unit_tests/controllers/web/test_human_input_form.py
Normal file
456
api/tests/unit_tests/controllers/web/test_human_input_form.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Unit tests for controllers.web.human_input_form endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import controllers.web.human_input_form as human_input_module
|
||||
import controllers.web.site as site_module
|
||||
from controllers.web.error import WebFormRateLimitExceededError
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import FormExpiredError
|
||||
|
||||
HumanInputFormApi = human_input_module.HumanInputFormApi
|
||||
TenantStatus = human_input_module.TenantStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Simple stand-in for db.session that returns pre-seeded objects."""
|
||||
|
||||
def __init__(self, mapping: dict[str, Any]):
|
||||
self._mapping = mapping
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model):
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
|
||||
class _FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
def __init__(self, session: _FakeSession):
|
||||
self.session = session
|
||||
self.engine = object()
|
||||
|
||||
|
||||
def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET returns form definition merged with site payload."""
|
||||
|
||||
expiration_time = datetime(2099, 1, 1, tzinfo=UTC)
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"form_content": "Raw content",
|
||||
"rendered_content": "Rendered {{#$output.name#}}",
|
||||
"inputs": [{"type": "text", "output_variable_name": "name", "default": None}],
|
||||
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
|
||||
"user_actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
|
||||
}
|
||||
|
||||
class _FakeForm:
|
||||
def __init__(self, expiration: datetime):
|
||||
self.workflow_run_id = "workflow-1"
|
||||
self.app_id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.expiration_time = expiration
|
||||
self.recipient_type = RecipientType.BACKSTAGE
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm(expiration_time)
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.NORMAL,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
site_model = SimpleNamespace(
|
||||
title="My Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
# Patch service to return fake form.
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
# Patch db session.
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
monkeypatch.setattr(
|
||||
site_module.FeatureService,
|
||||
"get_features",
|
||||
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
|
||||
)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
assert set(body.keys()) == {
|
||||
"site",
|
||||
"form_content",
|
||||
"inputs",
|
||||
"resolved_default_values",
|
||||
"user_actions",
|
||||
"expiration_time",
|
||||
}
|
||||
assert body["form_content"] == "Rendered {{#$output.name#}}"
|
||||
assert body["inputs"] == [{"type": "text", "output_variable_name": "name", "default": None}]
|
||||
assert body["resolved_default_values"] == {"name": "Alice", "age": "30", "meta": '{"k": "v"}'}
|
||||
assert body["user_actions"] == [{"id": "approve", "title": "Approve", "button_style": "default"}]
|
||||
assert body["expiration_time"] == int(expiration_time.timestamp())
|
||||
assert body["site"] == {
|
||||
"app_id": "app-1",
|
||||
"end_user_id": None,
|
||||
"enable_site": True,
|
||||
"site": {
|
||||
"title": "My Site",
|
||||
"chat_color_theme": "light",
|
||||
"chat_color_theme_inverted": False,
|
||||
"icon_type": "emoji",
|
||||
"icon": "robot",
|
||||
"icon_background": "#fff",
|
||||
"icon_url": None,
|
||||
"description": "desc",
|
||||
"copyright": None,
|
||||
"privacy_policy": None,
|
||||
"custom_disclaimer": None,
|
||||
"default_language": "en",
|
||||
"prompt_public": False,
|
||||
"show_workflow_steps": True,
|
||||
"use_icon_as_answer_icon": False,
|
||||
},
|
||||
"model_config": None,
|
||||
"plan": "basic",
|
||||
"can_replace_logo": True,
|
||||
"custom_config": {
|
||||
"remove_webapp_brand": True,
|
||||
"replace_webapp_logo": None,
|
||||
},
|
||||
}
|
||||
service_mock.get_form_by_token.assert_called_once_with("token-1")
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET returns form payload for backstage token."""
|
||||
|
||||
expiration_time = datetime(2099, 1, 2, tzinfo=UTC)
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"form_content": "Raw content",
|
||||
"rendered_content": "Rendered",
|
||||
"inputs": [],
|
||||
"default_values": {},
|
||||
"user_actions": [],
|
||||
}
|
||||
|
||||
class _FakeForm:
|
||||
def __init__(self, expiration: datetime):
|
||||
self.workflow_run_id = "workflow-1"
|
||||
self.app_id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.expiration_time = expiration
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm(expiration_time)
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.NORMAL,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
site_model = SimpleNamespace(
|
||||
title="My Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
monkeypatch.setattr(
|
||||
site_module.FeatureService,
|
||||
"get_features",
|
||||
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
|
||||
)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
assert set(body.keys()) == {
|
||||
"site",
|
||||
"form_content",
|
||||
"inputs",
|
||||
"resolved_default_values",
|
||||
"user_actions",
|
||||
"expiration_time",
|
||||
}
|
||||
assert body["form_content"] == "Rendered"
|
||||
assert body["inputs"] == []
|
||||
assert body["resolved_default_values"] == {}
|
||||
assert body["user_actions"] == []
|
||||
assert body["expiration_time"] == int(expiration_time.timestamp())
|
||||
assert body["site"] == {
|
||||
"app_id": "app-1",
|
||||
"end_user_id": None,
|
||||
"enable_site": True,
|
||||
"site": {
|
||||
"title": "My Site",
|
||||
"chat_color_theme": "light",
|
||||
"chat_color_theme_inverted": False,
|
||||
"icon_type": "emoji",
|
||||
"icon": "robot",
|
||||
"icon_background": "#fff",
|
||||
"icon_url": None,
|
||||
"description": "desc",
|
||||
"copyright": None,
|
||||
"privacy_policy": None,
|
||||
"custom_disclaimer": None,
|
||||
"default_language": "en",
|
||||
"prompt_public": False,
|
||||
"show_workflow_steps": True,
|
||||
"use_icon_as_answer_icon": False,
|
||||
},
|
||||
"model_config": None,
|
||||
"plan": "basic",
|
||||
"can_replace_logo": True,
|
||||
"custom_config": {
|
||||
"remove_webapp_brand": True,
|
||||
"replace_webapp_logo": None,
|
||||
},
|
||||
}
|
||||
service_mock.get_form_by_token.assert_called_once_with("token-1")
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET raises Forbidden if site cannot be resolved."""
|
||||
|
||||
expiration_time = datetime(2099, 1, 3, tzinfo=UTC)
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"form_content": "Raw content",
|
||||
"rendered_content": "Rendered",
|
||||
"inputs": [],
|
||||
"default_values": {},
|
||||
"user_actions": [],
|
||||
}
|
||||
|
||||
class _FakeForm:
|
||||
def __init__(self, expiration: datetime):
|
||||
self.workflow_run_id = "workflow-1"
|
||||
self.app_id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.expiration_time = expiration
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm(expiration_time)
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
tenant = SimpleNamespace(status=TenantStatus.NORMAL)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(Forbidden):
|
||||
HumanInputFormApi().get("token-1")
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_submit_form_accepts_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""POST forwards backstage submissions to the service."""
|
||||
|
||||
class _FakeForm:
|
||||
recipient_type = RecipientType.BACKSTAGE
|
||||
|
||||
form = _FakeForm()
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
response, status = HumanInputFormApi().post("token-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {}
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
form_token="token-1",
|
||||
selected_action_id="approve",
|
||||
form_data={"content": "ok"},
|
||||
submission_end_user_id=None,
|
||||
)
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_submit_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""POST rejects submissions when rate limit is exceeded."""
|
||||
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = True
|
||||
monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = None
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(WebFormRateLimitExceededError):
|
||||
HumanInputFormApi().post("token-1")
|
||||
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_not_called()
|
||||
service_mock.get_form_by_token.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET rejects requests when rate limit is exceeded."""
|
||||
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = True
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = None
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(WebFormRateLimitExceededError):
|
||||
HumanInputFormApi().get("token-1")
|
||||
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_not_called()
|
||||
service_mock.get_form_by_token.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_raises_expired(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
class _FakeForm:
|
||||
pass
|
||||
|
||||
form = _FakeForm()
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
service_mock.ensure_form_active.side_effect = FormExpiredError("form-id")
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(FormExpiredError):
|
||||
HumanInputFormApi().get("token-1")
|
||||
|
||||
service_mock.ensure_form_active.assert_called_once_with(form)
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
@@ -12,6 +13,8 @@ import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
@@ -137,6 +140,12 @@ def test_message_list_mapping(app: Flask) -> None:
|
||||
status="success",
|
||||
error=None,
|
||||
message_metadata_dict={"meta": "value"},
|
||||
extra_contents=[
|
||||
HumanInputContent(
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
submitted=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
pagination = SimpleNamespace(limit=20, has_more=False, data=[message])
|
||||
@@ -169,6 +178,8 @@ def test_message_list_mapping(app: Flask) -> None:
|
||||
|
||||
assert item["agent_thoughts"][0]["chain_id"] == "chain-1"
|
||||
assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp())
|
||||
assert item["extra_contents"][0]["workflow_run_id"] == message.extra_contents[0].workflow_run_id
|
||||
assert item["extra_contents"][0]["submitted"] == message.extra_contents[0].submitted
|
||||
|
||||
assert item["message_files"][0]["id"] == "file-dict"
|
||||
assert item["message_files"][1]["id"] == "file-obj"
|
||||
|
||||
Reference in New Issue
Block a user