Merge branch 'main' into tp

This commit is contained in:
JzoNg
2026-04-27 10:20:08 +08:00
316 changed files with 12484 additions and 7806 deletions

View File

@@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
@@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")

View File

@@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
# Act
@@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Act
with self.app.test_request_context(
@@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
# Act
@@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")

View File

@@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
- IP rate limiting is checked
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "email_token_123"
@@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is allowed by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = True
@@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is blocked by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = False
@@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
- Prevents spam and abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.side_effect = AccountRegisterError("Account frozen")
@@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
- Defaults to en-US when not specified
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token"
@@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
- User is logged in with token pair
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
- User is logged in after account creation
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
mock_get_user.return_value = None
mock_create_account.return_value = mock_account
@@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
- InvalidEmailError is raised when email doesn't match token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
# Act & Assert
@@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
- EmailCodeError is raised for wrong verification code
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
# Act & Assert
@@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
- User is added as owner of new workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
- WorkspacesLimitExceeded is raised when limit reached
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
- NotAllowedCreateWorkspace is raised when creation disabled
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []

View File

@@ -110,7 +110,6 @@ class TestLoginApi:
- Rate limit is reset after successful login
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@@ -162,7 +161,6 @@ class TestLoginApi:
- Authentication proceeds with invitation token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
mock_authenticate.return_value = mock_account
@@ -199,7 +197,6 @@ class TestLoginApi:
- EmailPasswordLoginLimitError is raised when limit exceeded
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
mock_get_invitation.return_value = None
@@ -228,7 +225,6 @@ class TestLoginApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_frozen.return_value = True
# Act & Assert
@@ -268,7 +264,6 @@ class TestLoginApi:
- Generic error message prevents user enumeration
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
@@ -305,7 +300,6 @@ class TestLoginApi:
- Login is prevented even with valid credentials
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountLoginError("Account is banned")
@@ -351,7 +345,6 @@ class TestLoginApi:
- User cannot login without an assigned workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@@ -383,7 +376,6 @@ class TestLoginApi:
- Security check prevents invitation token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
@@ -425,7 +417,6 @@ class TestLoginApi:
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
@@ -459,7 +450,6 @@ class TestLoginApi:
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")
@@ -513,7 +503,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
# Act
@@ -539,7 +528,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
# Create a mock anonymous user that will pass isinstance check
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})

View File

@@ -46,7 +46,6 @@ class TestPartnerTenants:
patch("libs.login.dify_config.LOGIN_DISABLED", False),
patch("libs.login.check_csrf_token") as mock_csrf,
):
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}

View File

@@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
TagBindingDeleteApi,
DeprecatedTagBindingCreateApi,
DeprecatedTagBindingRemoveApi,
TagBindingCollectionApi,
TagBindingItemApi,
TagListApi,
TagUpdateDeleteApi,
)
@@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi:
assert status == 204
class TestTagBindingCreateApi:
class TestTagBindingCollectionApi:
def test_create_success(self, app, admin_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
payload = {
@@ -232,7 +234,7 @@ class TestTagBindingCreateApi:
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@@ -247,9 +249,78 @@ class TestTagBindingCreateApi:
method(api)
class TestTagBindingDeleteApi:
class TestDeprecatedTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = DeprecatedTagBindingCreateApi()
method = unwrap(api.post)
payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
class TestTagBindingItemApi:
def test_delete_success(self, app, admin_user, payload_patch):
api = TagBindingItemApi()
method = unwrap(api.delete)
payload = {
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api, "tag-1")
delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
assert delete_payload.tag_id == "tag-1"
assert delete_payload.target_id == "target-1"
assert delete_payload.type == TagType.KNOWLEDGE
assert status == 200
assert result["result"] == "success"
def test_delete_forbidden(self, app, readonly_user):
api = TagBindingItemApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
class TestDeprecatedTagBindingRemoveApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
payload = {
@@ -274,7 +345,7 @@ class TestTagBindingDeleteApi:
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@@ -297,3 +368,35 @@ class TestTagResponseModel:
assert payload["type"] == "knowledge"
assert payload["binding_count"] == "1"
class TestTagBindingRouteMetadata:
def test_legacy_write_routes_are_marked_deprecated(self):
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True
def test_write_routes_have_stable_operation_ids(self):
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"
def test_canonical_and_legacy_write_routes_are_registered(self):
route_map = {
resource.__name__: urls
for resource, urls, _route_doc, _kwargs in console_ns.resources
if resource.__name__
in {
"TagBindingCollectionApi",
"TagBindingItemApi",
"DeprecatedTagBindingCreateApi",
"DeprecatedTagBindingRemoveApi",
}
}
assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)

View File

@@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
handler(api, form_token="token")
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)

View File

@@ -24,10 +24,6 @@ def app():
return app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
@@ -64,7 +60,6 @@ class TestChangeEmailSend:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@@ -117,7 +112,6 @@ class TestChangeEmailSend:
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@@ -163,7 +157,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
@@ -223,7 +216,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -280,7 +272,6 @@ class TestChangeEmailValidity:
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -330,7 +321,6 @@ class TestChangeEmailValidity:
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -378,7 +368,6 @@ class TestChangeEmailReset:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -434,7 +423,6 @@ class TestChangeEmailReset:
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -488,7 +476,6 @@ class TestChangeEmailReset:
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@@ -578,7 +564,6 @@ class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask, g
@@ -16,10 +16,6 @@ def app():
return flask_app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
@@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@@ -310,7 +310,6 @@ class TestSystemSetup:
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
@setup_required
def admin_view():

View File

@@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None
@contextmanager
def _mock_db():
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
with patch("extensions.ext_database.db.session", mock_session):
yield

View File

@@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
from controllers.service_api.app.error import AppUnavailableError
from models.account import TenantStatus
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
class TestAppParameterApi:
@@ -74,7 +74,7 @@ class TestAppParameterApi:
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -120,7 +120,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -161,7 +161,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -200,7 +200,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -263,7 +263,7 @@ class TestAppMetaApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -331,7 +331,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -388,7 +388,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -434,7 +434,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -486,7 +486,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):

View File

@@ -0,0 +1,707 @@
"""Dedicated tests for HITL behavior exposed through the Service API."""
from __future__ import annotations
import json
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock
import pytest
import services.app_generate_service as ags_module
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.common import workflow_response_converter
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
HumanInputRequiredResponse,
WorkflowAppPausedBlockingResponse,
WorkflowPauseStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.human_input_policy import HumanInputSurface
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from graphon.nodes.human_input.entities import FormInput, UserAction
from graphon.nodes.human_input.enums import FormInputType
from graphon.runtime import GraphRuntimeState, VariablePool
from models.account import Account
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.app_generate_service import AppGenerateService
from services.workflow_event_snapshot_service import _build_snapshot_events
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class _DummyRateLimit:
@staticmethod
def gen_request_key() -> str:
return "dummy-request-id"
def __init__(self, client_id: str, max_active_requests: int) -> None:
self.client_id = client_id
self.max_active_requests = max_active_requests
def enter(self, request_id: str | None = None) -> str:
return request_id or "dummy-request-id"
def exit(self, request_id: str) -> None:
return None
def generate(self, generator, request_id: str):
return generator
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
monkeypatch.setattr(
workflow_events_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
return workflow_events_module
def _build_service_api_pause_converter() -> WorkflowResponseConverter:
application_generate_entity = SimpleNamespace(
inputs={},
files=[],
invoke_from=InvokeFrom.SERVICE_API,
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
)
system_variables = build_system_variables(
user_id="user",
app_id="app-id",
workflow_id="workflow-id",
workflow_execution_id="run-id",
)
user = MagicMock(spec=Account)
user.id = "account-id"
user.name = "Tester"
user.email = "tester@example.com"
return WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=system_variables,
)
def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
data = AdvancedChatPausedBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
workflow_run_id="run-1",
answer="partial",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
paused_nodes=["node-1"],
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"expiration_time": 100,
}
],
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
)
return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
return WorkflowAppPausedBlockingResponse(
task_id="t1",
workflow_run_id="r1",
data=WorkflowAppPausedBlockingResponse.Data(
id="r1",
workflow_id="wf-1",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=0.5,
total_tokens=0,
total_steps=2,
created_at=1,
finished_at=None,
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
),
)
@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
pause_id: str
workflow_run_id: str
paused_at_value: datetime
pause_reasons: Sequence[HumanInputRequired]
@property
def id(self) -> str:
return self.pause_id
@property
def workflow_execution_id(self) -> str:
return self.workflow_run_id
def get_state(self) -> bytes:
raise AssertionError("state is not required for snapshot tests")
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return self.paused_at_value
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
return self.pause_reasons
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"input": "value"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
created_at = datetime(2024, 1, 1, tzinfo=UTC)
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
return WorkflowNodeExecutionSnapshot(
execution_id="exec-1",
node_id="node-1",
node_type="human-input",
title="Human Input",
index=1,
status=status.value,
elapsed_time=0.5,
created_at=created_at,
finished_at=finished_at,
iteration_id=None,
loop_id=None,
)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.register_paused_node("node-1")
runtime_state.outputs = {"result": "value"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
class TestHitlServiceApi:
# Service API event-stream continuation
def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
msg_generator.retrieve_events.return_value = ["raw-event"]
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
msg_generator.retrieve_events.assert_called_once_with(
AppMode.WORKFLOW,
"run-1",
terminal_events=[],
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
self, app, monkeypatch: pytest.MonkeyPatch
) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
snapshot_builder = Mock(return_value=["snapshot-events"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
method="GET",
):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called()
snapshot_builder.assert_called_once_with(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=ANY,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=False,
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
workflow = MagicMock()
workflow.created_by = "owner-id"
monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
generator_instance = MagicMock()
generator_instance.generate.return_value = {"result": "advanced-blocking"}
generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
app_model = MagicMock()
app_model.mode = AppMode.ADVANCED_CHAT
app_model.id = "app-id"
app_model.tenant_id = "tenant-id"
app_model.max_active_requests = 0
app_model.is_agent = False
user = MagicMock()
user.id = "user-id"
result = AppGenerateService.generate(
app_model=app_model,
user=user,
args={"workflow_id": None, "query": "hi", "inputs": {}},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
assert result == {"result": "advanced-blocking"}
call_kwargs = generator_instance.generate.call_args.kwargs
assert call_kwargs["streaming"] is False
assert call_kwargs["pause_state_config"] is not None
assert call_kwargs["pause_state_config"].session_factory == "session-maker"
assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
# Blocking payload contract
def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
_build_advanced_chat_paused_blocking_response()
)
assert response["event"] == "workflow_paused"
assert response["workflow_run_id"] == "run-1"
assert response["answer"] == "partial"
assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
assert response["data"]["reasons"][0]["expiration_time"] == 100
assert "human_input_forms" not in response["data"]
def test_workflow_blocking_pause_payload_contract(self) -> None:
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
_build_workflow_paused_blocking_response()
)
assert response["workflow_run_id"] == "r1"
assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
assert response["data"]["paused_nodes"] == ["node-1"]
assert response["data"]["reasons"] == [
{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}
]
assert "human_input_forms" not in response["data"]
def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
from core.app.app_config.entities import AppAdditionalFeatures
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from models.enums import MessageStatus
from models.model import EndUser
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.ADVANCED_CHAT,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
query="hello",
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
extras={},
trace_manager=None,
workflow_run_id="run-id",
)
pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
message=SimpleNamespace(
id="message-id",
query="hello",
created_at=datetime.utcnow(),
status=MessageStatus.NORMAL,
answer="",
),
user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
stream=False,
dialogue_count=1,
draft_var_saver_factory=lambda **kwargs: None,
)
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Approval",
form_content="Need approval",
inputs=[],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
form_token="token-1",
resolved_default_values={},
expiration_time=123,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run-id",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run-id",
paused_nodes=["node-1"],
outputs={},
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"node_id": "node-1",
"expiration_time": 123,
},
],
status="paused",
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, AdvancedChatPausedBlockingResponse)
assert response.data.answer == "partial answer"
assert response.data.workflow_run_id == "run-id"
assert response.data.reasons[0]["form_id"] == "form-1"
assert response.data.reasons[0]["expiration_time"] == 123
def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
trace_manager=None,
workflow_execution_id="run-id",
extras={},
call_depth=0,
)
pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
user=SimpleNamespace(id="user", session_id="session"),
stream=False,
draft_var_saver_factory=lambda **kwargs: None,
)
monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000)
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Human Input",
form_content="content",
expiration_time=1,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, WorkflowAppPausedBlockingResponse)
assert response.data.status == WorkflowExecutionStatus.PAUSED
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}]
def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
converter = _build_service_api_pause_converter()
converter.workflow_start_to_stream_response(
task_id="task",
workflow_run_id="run-id",
workflow_id="workflow-id",
reason=WorkflowStartReason.INITIAL,
)
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
class _FakeSession:
def execute(self, _stmt):
return [("form-1", expiration_time, '{"display_in_ui": true}')]
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
workflow_response_converter,
"load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "token"},
)
reason = HumanInputRequired(
form_id="form-1",
form_content="Rendered",
inputs=[
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
node_id="node-id",
node_title="Human Step",
form_token="token",
)
queue_event = QueueWorkflowPausedEvent(
reasons=[reason],
outputs={"answer": "value"},
paused_nodes=["node-id"],
)
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
responses = converter.workflow_pause_to_stream_response(
event=queue_event,
task_id="task",
graph_runtime_state=runtime_state,
)
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
pause_resp = responses[-1]
assert pause_resp.workflow_run_id == "run-id"
assert pause_resp.data.paused_nodes == ["node-id"]
assert pause_resp.data.outputs == {}
assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required"
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
assert pause_resp.data.reasons[0]["form_token"] == "token"
assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())
assert isinstance(responses[0], HumanInputRequiredResponse)
hi_resp = responses[0]
assert hi_resp.data.form_id == "form-1"
assert hi_resp.data.node_id == "node-id"
assert hi_resp.data.node_title == "Human Step"
assert hi_resp.data.inputs[0].output_variable_name == "field"
assert hi_resp.data.actions[0].id == "approve"
assert hi_resp.data.display_in_ui is True
assert hi_resp.data.form_token == "token"
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
# Snapshot payload contract
def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
resumption_context = _build_resumption_context("task-ctx")
monkeypatch.setattr(
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
)
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def session_maker() -> _SessionContext:
return _SessionContext(
SimpleNamespace(
execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
)
)
pause_entity = _FakePauseEntity(
pause_id="pause-1",
workflow_run_id="run-1",
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
pause_reasons=[
HumanInputRequired(
form_id="form-1",
form_content="content",
node_id="node-1",
node_title="Human Input",
form_token="wtok",
)
],
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-ctx",
message_context=None,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
)
assert [event["event"] for event in events] == [
"workflow_started",
"node_started",
"node_finished",
"human_input_required",
"workflow_paused",
]
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
assert events[3]["data"]["form_token"] == "wtok"
assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
pause_data = events[-1]["data"]
assert pause_data["paused_nodes"] == ["node-1"]
assert pause_data["outputs"] == {"result": "value"}
assert pause_data["reasons"][0]["TYPE"] == "human_input_required"
assert pause_data["reasons"][0]["form_token"] == "wtok"
assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
assert pause_data["total_tokens"] == workflow_run.total_tokens
assert pause_data["total_steps"] == workflow_run.total_steps

View File

@@ -0,0 +1,184 @@
"""Unit tests for Service API human input form endpoints."""
from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
from models.human_input import RecipientType
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class TestWorkflowHumanInputFormApi:
def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
definition = SimpleNamespace(
model_dump=lambda: {
"rendered_content": "Rendered form content",
"inputs": [{"output_variable_name": "name"}],
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
"user_actions": [{"id": "approve", "title": "Approve"}],
}
)
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=RecipientType.STANDALONE_WEB_APP,
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
get_definition=lambda: definition,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
response = handler(api, app_model=app_model, form_token="token-1")
payload = json.loads(response.get_data(as_text=True))
assert payload == {
"form_content": "Rendered form content",
"inputs": [{"output_variable_name": "name"}],
"resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'},
"user_actions": [{"id": "approve", "title": "Approve"}],
"expiration_time": int(form.expiration_time.timestamp()),
}
service_mock.get_form_by_token.assert_called_once_with("token-1")
service_mock.ensure_form_active.assert_called_once_with(form)
def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(
app_id="another-app",
tenant_id="tenant-1",
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, form_token="token-1")
@pytest.mark.parametrize(
"recipient_type",
[
RecipientType.CONSOLE,
RecipientType.BACKSTAGE,
RecipientType.EMAIL_MEMBER,
RecipientType.EMAIL_EXTERNAL,
],
)
def test_get_rejects_non_service_api_recipient_types(
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=recipient_type,
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, form_token="token-1")
service_mock.ensure_form_active.assert_not_called()
def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=RecipientType.STANDALONE_WEB_APP,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/form/human_input/token-1",
method="POST",
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
):
response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
assert response == {}
assert status == 200
service_mock.submit_form_by_token.assert_called_once_with(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token-1",
selected_action_id="approve",
form_data={"name": "Alice"},
submission_end_user_id="end-user-1",
)
@pytest.mark.parametrize(
"recipient_type",
[
RecipientType.CONSOLE,
RecipientType.BACKSTAGE,
RecipientType.EMAIL_MEMBER,
RecipientType.EMAIL_EXTERNAL,
],
)
def test_post_rejects_non_service_api_recipient_types(
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=recipient_type,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/form/human_input/token-1",
method="POST",
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
service_mock.submit_form_by_token.assert_not_called()

View File

@@ -0,0 +1,166 @@
"""Unit tests for Service API workflow event stream endpoints."""
from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from models.enums import CreatorUserRole
from models.model import AppMode
from tests.unit_tests.controllers.service_api.conftest import _unwrap
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
monkeypatch.setattr(
workflow_events_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
return workflow_events_module
class TestWorkflowEventsApi:
def test_wrong_app_mode(self, app) -> None:
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
_mock_repo_for_run(monkeypatch, workflow_run=None)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="another-user",
finished_at=None,
)
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=datetime(2099, 1, 1, tzinfo=UTC),
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
monkeypatch.setattr(
workflow_events_module.WorkflowResponseConverter,
"workflow_run_result_to_finish_response",
lambda **_kwargs: SimpleNamespace(
model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"},
event=SimpleNamespace(value="workflow_finished"),
),
)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.mimetype == "text/event-stream"
body = response.get_data(as_text=True).strip()
assert body.startswith("data: ")
payload = json.loads(body[len("data: ") :])
assert payload["task_id"] == "run-1"
assert payload["event"] == "workflow_finished"
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
msg_generator.retrieve_events.return_value = ["raw-event"]
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
msg_generator.retrieve_events.assert_called_once_with(
AppMode.WORKFLOW,
"run-1",
terminal_events=None,
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
snapshot_builder = Mock(return_value=["snapshot-events"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called()
snapshot_builder.assert_called_once()
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])

View File

@@ -15,7 +15,10 @@ from flask import Flask
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import (
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@pytest.fixture
@@ -123,9 +126,7 @@ class AuthenticationMocker:
mock_db.session.get.side_effect = [mock_app, mock_tenant]
if mock_account:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@staticmethod
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
@@ -133,8 +134,7 @@ class AuthenticationMocker:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
mock_db.session.get.return_value = mock_account

View File

@@ -701,8 +701,8 @@ class TestDocumentApiDelete:
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
internally calls ``validate_and_get_api_token``. To bypass the decorator
we call the original function via ``__wrapped__`` (preserved by
``functools.wraps``). ``delete`` queries the dataset via
``db.session.query(Dataset)`` directly, so we patch ``db`` at the
``functools.wraps``). ``delete`` loads the dataset via
``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
controller module.
"""

View File

@@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
from models.model import ApiToken
from tests.unit_tests.conftest import (
setup_mock_dataset_tenant_query,
setup_mock_tenant_account_query,
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@@ -141,14 +141,11 @@ class TestValidateAppToken:
mock_account = Mock()
mock_account.id = str(uuid.uuid4())
mock_ta = Mock()
mock_ta.account_id = mock_account.id
# Use side_effect to return app first, then tenant via session.get()
mock_db.session.get.side_effect = [mock_app, mock_tenant]
# Mock the tenant owner query (execute(select(...)).one_or_none())
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
# Mock the tenant owner execute result (execute(select(...)).one_or_none())
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@validate_app_token
def protected_view(app_model):
@@ -471,7 +468,7 @@ class TestValidateDatasetToken:
mock_account.current_tenant = mock_tenant
# Mock the tenant account join query (execute(select(...)).one_or_none())
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
# Mock the account lookup via session.get()
mock_db.session.get.return_value = mock_account

View File

@@ -22,18 +22,16 @@ class FakeSession:
def __init__(self, mapping: dict[str, Any] | None = None):
self._mapping: dict[str, Any] = mapping or {}
self._model_name: str | None = None
def query(self, model: type) -> FakeSession:
self._model_name = model.__name__
return self
def get(self, model: type, _ident: object) -> Any:
return self._mapping.get(model.__name__)
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
return self
def first(self) -> Any:
assert self._model_name is not None
return self._mapping.get(self._model_name)
def scalar(self, stmt: Any) -> Any:
try:
model = stmt.column_descriptions[0]["entity"]
except (AttributeError, IndexError, KeyError, TypeError):
return None
return self._mapping.get(model.__name__)
class FakeDB:

View File

@@ -36,18 +36,6 @@ class _FakeSession:
def __init__(self, mapping: dict[str, Any]):
self._mapping = mapping
self._model_name: str | None = None
def query(self, model):
self._model_name = model.__name__
return self
def where(self, *args, **kwargs):
return self
def first(self):
assert self._model_name is not None
return self._mapping.get(self._model_name)
def get(self, model, ident):
return self._mapping.get(model.__name__)

View File

@@ -34,7 +34,6 @@ def _patch_wraps():
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
patch("controllers.web.login.dify_config", web_dify),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield