mirror of
https://github.com/langgenius/dify.git
synced 2026-05-16 07:00:36 -04:00
Merge branch 'main' into tp
This commit is contained in:
@@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
@@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
@@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with (
|
||||
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
|
||||
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
|
||||
@@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
|
||||
# Act
|
||||
@@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
@@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
# Act
|
||||
@@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
|
||||
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
|
||||
"""Test that reset password returns success with token for existing accounts."""
|
||||
# Mock the setup check
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Test with existing account
|
||||
mock_get_user.return_value = MagicMock(email="existing@example.com")
|
||||
|
||||
@@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "email_token_123"
|
||||
@@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Registration is allowed by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
@@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Registration is blocked by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = False
|
||||
@@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Prevents spam and abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.side_effect = AccountRegisterError("Account frozen")
|
||||
|
||||
@@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
|
||||
- Defaults to en-US when not specified
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
@@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is logged in with token pair
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
@@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is logged in after account creation
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = None
|
||||
mock_create_account.return_value = mock_account
|
||||
@@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
|
||||
- InvalidEmailError is raised when email doesn't match token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
@@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
|
||||
- EmailCodeError is raised for wrong verification code
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
@@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
|
||||
- User is added as owner of new workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
@@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
|
||||
- WorkspacesLimitExceeded is raised when limit reached
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
@@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
|
||||
- NotAllowedCreateWorkspace is raised when creation disabled
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
|
||||
@@ -110,7 +110,6 @@ class TestLoginApi:
|
||||
- Rate limit is reset after successful login
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
@@ -162,7 +161,6 @@ class TestLoginApi:
|
||||
- Authentication proceeds with invitation token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
|
||||
mock_authenticate.return_value = mock_account
|
||||
@@ -199,7 +197,6 @@ class TestLoginApi:
|
||||
- EmailPasswordLoginLimitError is raised when limit exceeded
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
@@ -228,7 +225,6 @@ class TestLoginApi:
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_frozen.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@@ -268,7 +264,6 @@ class TestLoginApi:
|
||||
- Generic error message prevents user enumeration
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
|
||||
@@ -305,7 +300,6 @@ class TestLoginApi:
|
||||
- Login is prevented even with valid credentials
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountLoginError("Account is banned")
|
||||
@@ -351,7 +345,6 @@ class TestLoginApi:
|
||||
- User cannot login without an assigned workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
@@ -383,7 +376,6 @@ class TestLoginApi:
|
||||
- Security check prevents invitation token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
|
||||
|
||||
@@ -425,7 +417,6 @@ class TestLoginApi:
|
||||
mock_token_pair,
|
||||
):
|
||||
"""Test that login retries with lowercase email when uppercase lookup fails."""
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
|
||||
@@ -459,7 +450,6 @@ class TestLoginApi:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_account.side_effect = Unauthorized("Account is banned.")
|
||||
|
||||
@@ -513,7 +503,6 @@ class TestLogoutApi:
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
|
||||
# Act
|
||||
@@ -539,7 +528,6 @@ class TestLogoutApi:
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
# Create a mock anonymous user that will pass isinstance check
|
||||
anonymous_user = MagicMock()
|
||||
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})
|
||||
|
||||
@@ -46,7 +46,6 @@ class TestPartnerTenants:
|
||||
patch("libs.login.dify_config.LOGIN_DISABLED", False),
|
||||
patch("libs.login.check_csrf_token") as mock_csrf,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_csrf.return_value = None
|
||||
yield {"db": mock_db, "csrf": mock_csrf}
|
||||
|
||||
|
||||
@@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden
|
||||
import controllers.console.tag.tags as module
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.tag.tags import (
|
||||
TagBindingCreateApi,
|
||||
TagBindingDeleteApi,
|
||||
DeprecatedTagBindingCreateApi,
|
||||
DeprecatedTagBindingRemoveApi,
|
||||
TagBindingCollectionApi,
|
||||
TagBindingItemApi,
|
||||
TagListApi,
|
||||
TagUpdateDeleteApi,
|
||||
)
|
||||
@@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi:
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestTagBindingCreateApi:
|
||||
class TestTagBindingCollectionApi:
|
||||
def test_create_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingCreateApi()
|
||||
api = TagBindingCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@@ -232,7 +234,7 @@ class TestTagBindingCreateApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_create_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingCreateApi()
|
||||
api = TagBindingCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
@@ -247,9 +249,78 @@ class TestTagBindingCreateApi:
|
||||
method(api)
|
||||
|
||||
|
||||
class TestTagBindingDeleteApi:
|
||||
class TestDeprecatedTagBindingCreateApi:
|
||||
def test_create_success(self, app, admin_user, payload_patch):
|
||||
api = DeprecatedTagBindingCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"tag_ids": ["tag-1"],
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestTagBindingItemApi:
|
||||
def test_delete_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingItemApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
delete_payload = delete_mock.call_args.args[0]
|
||||
assert delete_payload.tag_id == "tag-1"
|
||||
assert delete_payload.target_id == "target-1"
|
||||
assert delete_payload.type == TagType.KNOWLEDGE
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_forbidden(self, app, readonly_user):
|
||||
api = TagBindingItemApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "tag-1")
|
||||
|
||||
|
||||
class TestDeprecatedTagBindingRemoveApi:
|
||||
def test_remove_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
api = DeprecatedTagBindingRemoveApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@@ -274,7 +345,7 @@ class TestTagBindingDeleteApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_remove_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
api = DeprecatedTagBindingRemoveApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
@@ -297,3 +368,35 @@ class TestTagResponseModel:
|
||||
|
||||
assert payload["type"] == "knowledge"
|
||||
assert payload["binding_count"] == "1"
|
||||
|
||||
|
||||
class TestTagBindingRouteMetadata:
|
||||
def test_legacy_write_routes_are_marked_deprecated(self):
|
||||
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
|
||||
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
|
||||
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
|
||||
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True
|
||||
|
||||
def test_write_routes_have_stable_operation_ids(self):
|
||||
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
|
||||
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
|
||||
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
|
||||
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"
|
||||
|
||||
def test_canonical_and_legacy_write_routes_are_registered(self):
|
||||
route_map = {
|
||||
resource.__name__: urls
|
||||
for resource, urls, _route_doc, _kwargs in console_ns.resources
|
||||
if resource.__name__
|
||||
in {
|
||||
"TagBindingCollectionApi",
|
||||
"TagBindingItemApi",
|
||||
"DeprecatedTagBindingCreateApi",
|
||||
"DeprecatedTagBindingRemoveApi",
|
||||
}
|
||||
}
|
||||
|
||||
assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
|
||||
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
|
||||
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
|
||||
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)
|
||||
|
||||
@@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
|
||||
@@ -24,10 +24,6 @@ def app():
|
||||
return app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
|
||||
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
|
||||
account = Account(name=account_id, email=email)
|
||||
@@ -64,7 +60,6 @@ class TestChangeEmailSend:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@@ -117,7 +112,6 @@ class TestChangeEmailSend:
|
||||
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@@ -163,7 +157,6 @@ class TestChangeEmailValidity:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
@@ -223,7 +216,6 @@ class TestChangeEmailValidity:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@@ -280,7 +272,6 @@ class TestChangeEmailValidity:
|
||||
"""A token whose phase marker is a string but not a known transition must be rejected."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@@ -330,7 +321,6 @@ class TestChangeEmailValidity:
|
||||
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
@@ -378,7 +368,6 @@ class TestChangeEmailReset:
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@@ -434,7 +423,6 @@ class TestChangeEmailReset:
|
||||
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@@ -488,7 +476,6 @@ class TestChangeEmailReset:
|
||||
"""A verified token for address A must not be replayed to change to address B."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
@@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
|
||||
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
with app.test_request_context(
|
||||
"/account/delete/feedback",
|
||||
method="POST",
|
||||
@@ -578,7 +564,6 @@ class TestCheckEmailUnique:
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
@@ -16,10 +16,6 @@ def app():
|
||||
return flask_app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_feature_flags():
|
||||
placeholder_quota = SimpleNamespace(limit=0, size=0)
|
||||
workspace_members = SimpleNamespace(is_available=lambda count: True)
|
||||
@@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
|
||||
mock_get_features,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_get_features.return_value = _build_feature_flags()
|
||||
mock_invite_member.return_value = "token-abc"
|
||||
|
||||
|
||||
@@ -310,7 +310,6 @@ class TestSystemSetup:
|
||||
def test_should_allow_when_setup_complete(self, mock_db):
|
||||
"""Test that requests are allowed when setup is complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
|
||||
|
||||
@setup_required
|
||||
def admin_view():
|
||||
|
||||
@@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None
|
||||
|
||||
@contextmanager
|
||||
def _mock_db():
|
||||
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
|
||||
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
|
||||
with patch("extensions.ext_database.db.session", mock_session):
|
||||
yield
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, AppMode
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
|
||||
|
||||
|
||||
class TestAppParameterApi:
|
||||
@@ -74,7 +74,7 @@ class TestAppParameterApi:
|
||||
# Mock tenant owner info for login
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -120,7 +120,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -161,7 +161,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -200,7 +200,7 @@ class TestAppParameterApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -263,7 +263,7 @@ class TestAppMetaApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -331,7 +331,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -388,7 +388,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -434,7 +434,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -486,7 +486,7 @@ class TestAppInfoApi:
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
|
||||
@@ -0,0 +1,707 @@
|
||||
"""Dedicated tests for HITL behavior exposed through the Service API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
import services.app_generate_service as ags_module
|
||||
from controllers.service_api.app.workflow_events import WorkflowEventsApi
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.common import workflow_response_converter
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
HumanInputRequiredResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
from graphon.nodes.human_input.enums import FormInputType
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_event_snapshot_service import _build_snapshot_events
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class _DummyRateLimit:
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return "dummy-request-id"
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int) -> None:
|
||||
self.client_id = client_id
|
||||
self.max_active_requests = max_active_requests
|
||||
|
||||
def enter(self, request_id: str | None = None) -> str:
|
||||
return request_id or "dummy-request-id"
|
||||
|
||||
def exit(self, request_id: str) -> None:
|
||||
return None
|
||||
|
||||
def generate(self, generator, request_id: str):
|
||||
return generator
|
||||
|
||||
|
||||
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
|
||||
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
|
||||
return workflow_events_module
|
||||
|
||||
|
||||
def _build_service_api_pause_converter() -> WorkflowResponseConverter:
|
||||
application_generate_entity = SimpleNamespace(
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
|
||||
)
|
||||
system_variables = build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
)
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "account-id"
|
||||
user.name = "Tester"
|
||||
user.email = "tester@example.com"
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
|
||||
data = AdvancedChatPausedBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
workflow_run_id="run-1",
|
||||
answer="partial",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[
|
||||
{
|
||||
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
"form_id": "form-1",
|
||||
"expiration_time": 100,
|
||||
}
|
||||
],
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
)
|
||||
return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
|
||||
def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id="r1",
|
||||
workflow_id="wf-1",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
error=None,
|
||||
elapsed_time=0.5,
|
||||
total_tokens=0,
|
||||
total_steps=2,
|
||||
created_at=1,
|
||||
finished_at=None,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _FakePauseEntity(WorkflowPauseEntity):
|
||||
pause_id: str
|
||||
workflow_run_id: str
|
||||
paused_at_value: datetime
|
||||
pause_reasons: Sequence[HumanInputRequired]
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.pause_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return self.workflow_run_id
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
raise AssertionError("state is not required for snapshot tests")
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return self.paused_at_value
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
|
||||
return self.pause_reasons
|
||||
|
||||
|
||||
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"input": "value"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
|
||||
created_at = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
|
||||
return WorkflowNodeExecutionSnapshot(
|
||||
execution_id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
title="Human Input",
|
||||
index=1,
|
||||
status=status.value,
|
||||
elapsed_time=0.5,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
iteration_id=None,
|
||||
loop_id=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.register_paused_node("node-1")
|
||||
runtime_state.outputs = {"result": "value"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
class TestHitlServiceApi:
|
||||
# Service API event-stream continuation
|
||||
def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
msg_generator.retrieve_events.return_value = ["raw-event"]
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: streamed\n\n"
|
||||
msg_generator.retrieve_events.assert_called_once_with(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
terminal_events=[],
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
|
||||
|
||||
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
|
||||
snapshot_builder = Mock(return_value=["snapshot-events"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
|
||||
method="GET",
|
||||
):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: snapshot\n\n"
|
||||
msg_generator.retrieve_events.assert_not_called()
|
||||
snapshot_builder.assert_called_once_with(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=ANY,
|
||||
human_input_surface=HumanInputSurface.SERVICE_API,
|
||||
close_on_pause=False,
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
|
||||
|
||||
def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
|
||||
monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
|
||||
monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
|
||||
|
||||
generator_instance = MagicMock()
|
||||
generator_instance.generate.return_value = {"result": "advanced-blocking"}
|
||||
generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
|
||||
monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
|
||||
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.ADVANCED_CHAT
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"workflow_id": None, "query": "hi", "inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"result": "advanced-blocking"}
|
||||
call_kwargs = generator_instance.generate.call_args.kwargs
|
||||
assert call_kwargs["streaming"] is False
|
||||
assert call_kwargs["pause_state_config"] is not None
|
||||
assert call_kwargs["pause_state_config"].session_factory == "session-maker"
|
||||
assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
|
||||
|
||||
# Blocking payload contract
|
||||
def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
|
||||
_build_advanced_chat_paused_blocking_response()
|
||||
)
|
||||
|
||||
assert response["event"] == "workflow_paused"
|
||||
assert response["workflow_run_id"] == "run-1"
|
||||
assert response["answer"] == "partial"
|
||||
assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
assert response["data"]["reasons"][0]["expiration_time"] == 100
|
||||
assert "human_input_forms" not in response["data"]
|
||||
|
||||
def test_workflow_blocking_pause_payload_contract(self) -> None:
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
|
||||
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
|
||||
_build_workflow_paused_blocking_response()
|
||||
)
|
||||
|
||||
assert response["workflow_run_id"] == "r1"
|
||||
assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
|
||||
assert response["data"]["paused_nodes"] == ["node-1"]
|
||||
assert response["data"]["reasons"] == [
|
||||
{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}
|
||||
]
|
||||
assert "human_input_forms" not in response["data"]
|
||||
|
||||
def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
|
||||
from core.app.app_config.entities import AppAdditionalFeatures
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from models.enums import MessageStatus
|
||||
from models.model import EndUser
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
|
||||
message=SimpleNamespace(
|
||||
id="message-id",
|
||||
query="hello",
|
||||
created_at=datetime.utcnow(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
),
|
||||
user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
|
||||
stream=False,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
pipeline._task_state.answer = "partial answer"
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
|
||||
def _gen():
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
form_content="Need approval",
|
||||
inputs=[],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
form_token="token-1",
|
||||
resolved_default_values={},
|
||||
expiration_time=123,
|
||||
),
|
||||
)
|
||||
yield WorkflowPauseStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run-id",
|
||||
paused_nodes=["node-1"],
|
||||
outputs={},
|
||||
reasons=[
|
||||
{
|
||||
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
"form_id": "form-1",
|
||||
"node_id": "node-1",
|
||||
"expiration_time": 123,
|
||||
},
|
||||
],
|
||||
status="paused",
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, AdvancedChatPausedBlockingResponse)
|
||||
assert response.data.answer == "partial answer"
|
||||
assert response.data.workflow_run_id == "run-id"
|
||||
assert response.data.reasons[0]["form_id"] == "form-1"
|
||||
assert response.data.reasons[0]["expiration_time"] == 123
|
||||
|
||||
def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
extras={},
|
||||
call_depth=0,
|
||||
)
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
user=SimpleNamespace(id="user", session_id="session"),
|
||||
stream=False,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000)
|
||||
|
||||
def _gen():
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_content="content",
|
||||
expiration_time=1,
|
||||
),
|
||||
)
|
||||
yield WorkflowPauseStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, WorkflowAppPausedBlockingResponse)
|
||||
assert response.data.status == WorkflowExecutionStatus.PAUSED
|
||||
assert response.data.paused_nodes == ["node-1"]
|
||||
assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}]
|
||||
|
||||
def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
converter = _build_service_api_pause_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
class _FakeSession:
|
||||
def execute(self, _stmt):
|
||||
return [("form-1", expiration_time, '{"display_in_ui": true}')]
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
|
||||
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
workflow_response_converter,
|
||||
"load_form_tokens_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {"form-1": "token"},
|
||||
)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="Rendered",
|
||||
inputs=[
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
|
||||
],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
node_id="node-id",
|
||||
node_title="Human Step",
|
||||
form_token="token",
|
||||
)
|
||||
queue_event = QueueWorkflowPausedEvent(
|
||||
reasons=[reason],
|
||||
outputs={"answer": "value"},
|
||||
paused_nodes=["node-id"],
|
||||
)
|
||||
|
||||
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
|
||||
responses = converter.workflow_pause_to_stream_response(
|
||||
event=queue_event,
|
||||
task_id="task",
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
|
||||
pause_resp = responses[-1]
|
||||
assert pause_resp.workflow_run_id == "run-id"
|
||||
assert pause_resp.data.paused_nodes == ["node-id"]
|
||||
assert pause_resp.data.outputs == {}
|
||||
assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required"
|
||||
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
|
||||
assert pause_resp.data.reasons[0]["form_token"] == "token"
|
||||
assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())
|
||||
|
||||
assert isinstance(responses[0], HumanInputRequiredResponse)
|
||||
hi_resp = responses[0]
|
||||
assert hi_resp.data.form_id == "form-1"
|
||||
assert hi_resp.data.node_id == "node-id"
|
||||
assert hi_resp.data.node_title == "Human Step"
|
||||
assert hi_resp.data.inputs[0].output_variable_name == "field"
|
||||
assert hi_resp.data.actions[0].id == "approve"
|
||||
assert hi_resp.data.display_in_ui is True
|
||||
assert hi_resp.data.form_token == "token"
|
||||
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
|
||||
|
||||
# Snapshot payload contract
|
||||
def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
|
||||
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
|
||||
resumption_context = _build_resumption_context("task-ctx")
|
||||
monkeypatch.setattr(
|
||||
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
|
||||
)
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def session_maker() -> _SessionContext:
|
||||
return _SessionContext(
|
||||
SimpleNamespace(
|
||||
execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
|
||||
)
|
||||
)
|
||||
|
||||
pause_entity = _FakePauseEntity(
|
||||
pause_id="pause-1",
|
||||
workflow_run_id="run-1",
|
||||
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
pause_reasons=[
|
||||
HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_token="wtok",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
events = _build_snapshot_events(
|
||||
workflow_run=workflow_run,
|
||||
node_snapshots=[snapshot],
|
||||
task_id="task-ctx",
|
||||
message_context=None,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
|
||||
assert [event["event"] for event in events] == [
|
||||
"workflow_started",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"human_input_required",
|
||||
"workflow_paused",
|
||||
]
|
||||
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
|
||||
assert events[3]["data"]["form_token"] == "wtok"
|
||||
assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
|
||||
pause_data = events[-1]["data"]
|
||||
assert pause_data["paused_nodes"] == ["node-1"]
|
||||
assert pause_data["outputs"] == {"result": "value"}
|
||||
assert pause_data["reasons"][0]["TYPE"] == "human_input_required"
|
||||
assert pause_data["reasons"][0]["form_token"] == "wtok"
|
||||
assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
|
||||
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
|
||||
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
|
||||
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
|
||||
assert pause_data["total_tokens"] == workflow_run.total_tokens
|
||||
assert pause_data["total_steps"] == workflow_run.total_steps
|
||||
@@ -0,0 +1,184 @@
|
||||
"""Unit tests for Service API human input form endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
|
||||
from models.human_input import RecipientType
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class TestWorkflowHumanInputFormApi:
|
||||
def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
definition = SimpleNamespace(
|
||||
model_dump=lambda: {
|
||||
"rendered_content": "Rendered form content",
|
||||
"inputs": [{"output_variable_name": "name"}],
|
||||
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
|
||||
"user_actions": [{"id": "approve", "title": "Approve"}],
|
||||
}
|
||||
)
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
get_definition=lambda: definition,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
response = handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload == {
|
||||
"form_content": "Rendered form content",
|
||||
"inputs": [{"output_variable_name": "name"}],
|
||||
"resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'},
|
||||
"user_actions": [{"id": "approve", "title": "Approve"}],
|
||||
"expiration_time": int(form.expiration_time.timestamp()),
|
||||
}
|
||||
service_mock.get_form_by_token.assert_called_once_with("token-1")
|
||||
service_mock.ensure_form_active.assert_called_once_with(form)
|
||||
|
||||
def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="another-app",
|
||||
tenant_id="tenant-1",
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"recipient_type",
|
||||
[
|
||||
RecipientType.CONSOLE,
|
||||
RecipientType.BACKSTAGE,
|
||||
RecipientType.EMAIL_MEMBER,
|
||||
RecipientType.EMAIL_EXTERNAL,
|
||||
],
|
||||
)
|
||||
def test_get_rejects_non_service_api_recipient_types(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
|
||||
) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=recipient_type,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
service_mock.ensure_form_active.assert_not_called()
|
||||
|
||||
def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
|
||||
):
|
||||
response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
|
||||
|
||||
assert response == {}
|
||||
assert status == 200
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token-1",
|
||||
selected_action_id="approve",
|
||||
form_data={"name": "Alice"},
|
||||
submission_end_user_id="end-user-1",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"recipient_type",
|
||||
[
|
||||
RecipientType.CONSOLE,
|
||||
RecipientType.BACKSTAGE,
|
||||
RecipientType.EMAIL_MEMBER,
|
||||
RecipientType.EMAIL_EXTERNAL,
|
||||
],
|
||||
)
|
||||
def test_post_rejects_non_service_api_recipient_types(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
|
||||
) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=recipient_type,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
|
||||
|
||||
service_mock.submit_form_by_token.assert_not_called()
|
||||
@@ -0,0 +1,166 @@
|
||||
"""Unit tests for Service API workflow event stream endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.app.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
|
||||
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
|
||||
return workflow_events_module
|
||||
|
||||
|
||||
class TestWorkflowEventsApi:
|
||||
def test_wrong_app_mode(self, app) -> None:
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=None)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="another-user",
|
||||
finished_at=None,
|
||||
)
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.WorkflowResponseConverter,
|
||||
"workflow_run_result_to_finish_response",
|
||||
lambda **_kwargs: SimpleNamespace(
|
||||
model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"},
|
||||
event=SimpleNamespace(value="workflow_finished"),
|
||||
),
|
||||
)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
body = response.get_data(as_text=True).strip()
|
||||
assert body.startswith("data: ")
|
||||
payload = json.loads(body[len("data: ") :])
|
||||
assert payload["task_id"] == "run-1"
|
||||
assert payload["event"] == "workflow_finished"
|
||||
|
||||
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
msg_generator.retrieve_events.return_value = ["raw-event"]
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: streamed\n\n"
|
||||
msg_generator.retrieve_events.assert_called_once_with(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
terminal_events=None,
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
|
||||
|
||||
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
|
||||
snapshot_builder = Mock(return_value=["snapshot-events"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: snapshot\n\n"
|
||||
msg_generator.retrieve_events.assert_not_called()
|
||||
snapshot_builder.assert_called_once()
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
|
||||
@@ -15,7 +15,10 @@ from flask import Flask
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, AppMode, EndUser
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
from tests.unit_tests.conftest import (
|
||||
setup_mock_dataset_owner_execute_result,
|
||||
setup_mock_tenant_owner_execute_result,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -123,9 +126,7 @@ class AuthenticationMocker:
|
||||
mock_db.session.get.side_effect = [mock_app, mock_tenant]
|
||||
|
||||
if mock_account:
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
@staticmethod
|
||||
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
|
||||
@@ -133,8 +134,7 @@ class AuthenticationMocker:
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
|
||||
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
|
||||
mock_db.session.get.return_value = mock_account
|
||||
|
||||
|
||||
|
||||
@@ -701,8 +701,8 @@ class TestDocumentApiDelete:
|
||||
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
|
||||
internally calls ``validate_and_get_api_token``. To bypass the decorator
|
||||
we call the original function via ``__wrapped__`` (preserved by
|
||||
``functools.wraps``). ``delete`` queries the dataset via
|
||||
``db.session.query(Dataset)`` directly, so we patch ``db`` at the
|
||||
``functools.wraps``). ``delete`` loads the dataset via
|
||||
``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
|
||||
controller module.
|
||||
"""
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
|
||||
from models.account import TenantStatus
|
||||
from models.model import ApiToken
|
||||
from tests.unit_tests.conftest import (
|
||||
setup_mock_dataset_tenant_query,
|
||||
setup_mock_tenant_account_query,
|
||||
setup_mock_dataset_owner_execute_result,
|
||||
setup_mock_tenant_owner_execute_result,
|
||||
)
|
||||
|
||||
|
||||
@@ -141,14 +141,11 @@ class TestValidateAppToken:
|
||||
mock_account = Mock()
|
||||
mock_account.id = str(uuid.uuid4())
|
||||
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
# Use side_effect to return app first, then tenant via session.get()
|
||||
mock_db.session.get.side_effect = [mock_app, mock_tenant]
|
||||
|
||||
# Mock the tenant owner query (execute(select(...)).one_or_none())
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
|
||||
# Mock the tenant owner execute result (execute(select(...)).one_or_none())
|
||||
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(app_model):
|
||||
@@ -471,7 +468,7 @@ class TestValidateDatasetToken:
|
||||
mock_account.current_tenant = mock_tenant
|
||||
|
||||
# Mock the tenant account join query (execute(select(...)).one_or_none())
|
||||
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
|
||||
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
|
||||
|
||||
# Mock the account lookup via session.get()
|
||||
mock_db.session.get.return_value = mock_account
|
||||
|
||||
@@ -22,18 +22,16 @@ class FakeSession:
|
||||
|
||||
def __init__(self, mapping: dict[str, Any] | None = None):
|
||||
self._mapping: dict[str, Any] = mapping or {}
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model: type) -> FakeSession:
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
def get(self, model: type, _ident: object) -> Any:
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
def scalar(self, stmt: Any) -> Any:
|
||||
try:
|
||||
model = stmt.column_descriptions[0]["entity"]
|
||||
except (AttributeError, IndexError, KeyError, TypeError):
|
||||
return None
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
|
||||
@@ -36,18 +36,6 @@ class _FakeSession:
|
||||
|
||||
def __init__(self, mapping: dict[str, Any]):
|
||||
self._mapping = mapping
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model):
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
def get(self, model, ident):
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
@@ -34,7 +34,6 @@ def _patch_wraps():
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
patch("controllers.web.login.dify_config", web_dify),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user