mirror of
https://github.com/langgenius/dify.git
synced 2026-05-25 19:00:43 -04:00
refactor(api): migrate console.app.workflow to BaseModel (#36216)
Co-authored-by: WH-2099 <wh2099@pm.me> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import TypedDict, Unpack
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -35,9 +37,54 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
DEFAULT_WORKFLOW_TENANT_ID = "00000000-0000-0000-0000-000000000001"
|
||||
DEFAULT_WORKFLOW_APP_ID = "00000000-0000-0000-0000-000000000002"
|
||||
DEFAULT_WORKFLOW_CREATED_BY = "00000000-0000-0000-0000-000000000003"
|
||||
type WorkflowVariablePayload = dict[str, object]
|
||||
|
||||
|
||||
class WorkflowFactoryPayload(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
type: str
|
||||
version: str
|
||||
marked_name: str
|
||||
marked_comment: str
|
||||
graph: str
|
||||
features: str
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_by: str | None
|
||||
updated_at: datetime
|
||||
environment_variables: list[WorkflowVariablePayload]
|
||||
conversation_variables: list[WorkflowVariablePayload]
|
||||
rag_pipeline_variables: list[WorkflowVariablePayload]
|
||||
|
||||
|
||||
class WorkflowFactoryOverrides(TypedDict, total=False):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
type: str
|
||||
version: str
|
||||
marked_name: str
|
||||
marked_comment: str
|
||||
graph: str
|
||||
features: str
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_by: str | None
|
||||
updated_at: datetime
|
||||
environment_variables: list[WorkflowVariablePayload]
|
||||
conversation_variables: list[WorkflowVariablePayload]
|
||||
rag_pipeline_variables: list[WorkflowVariablePayload]
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
@@ -74,17 +121,52 @@ def make_node_execution(**overrides):
|
||||
return SimpleNamespace(**payload)
|
||||
|
||||
|
||||
def default_workflow_payload() -> WorkflowFactoryPayload:
|
||||
return {
|
||||
"id": "workflow-1",
|
||||
"tenant_id": DEFAULT_WORKFLOW_TENANT_ID,
|
||||
"app_id": DEFAULT_WORKFLOW_APP_ID,
|
||||
"type": "workflow",
|
||||
"version": "1",
|
||||
"marked_name": "Release 1",
|
||||
"marked_comment": "Initial release",
|
||||
"graph": json.dumps({"nodes": [], "edges": []}),
|
||||
"features": json.dumps({"file_upload": {"enabled": False}}),
|
||||
"created_by": DEFAULT_WORKFLOW_CREATED_BY,
|
||||
"created_at": datetime(2024, 1, 1, 12, 0, 0),
|
||||
"updated_by": None,
|
||||
"updated_at": datetime(2024, 1, 1, 12, 1, 0),
|
||||
"environment_variables": [],
|
||||
"conversation_variables": [],
|
||||
"rag_pipeline_variables": [],
|
||||
}
|
||||
|
||||
|
||||
def make_workflow(**overrides: Unpack[WorkflowFactoryOverrides]) -> Workflow:
|
||||
payload = default_workflow_payload()
|
||||
payload.update(overrides)
|
||||
return Workflow(**payload)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_author(db_session_with_containers: Session) -> Account:
|
||||
account = Account(name="Alice", email=f"alice-{uuid4()}@example.com")
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
return account
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_draft_success(self, app: Flask):
|
||||
def test_get_draft_success(self, app: Flask, workflow_author: Account):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow = make_workflow(created_by=workflow_author.id)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
@@ -97,7 +179,17 @@ class TestDraftWorkflowApi:
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
assert result["id"] == "workflow-1"
|
||||
assert result["graph"] == {"nodes": [], "edges": []}
|
||||
assert result["features"] == {"file_upload": {"enabled": False}}
|
||||
assert result["hash"] == workflow.unique_hash
|
||||
assert result["created_by"] == {
|
||||
"id": workflow_author.id,
|
||||
"name": workflow_author.name,
|
||||
"email": workflow_author.email,
|
||||
}
|
||||
assert result["updated_by"] is None
|
||||
|
||||
def test_get_draft_not_exist(self, app: Flask):
|
||||
api = DraftRagPipelineApi()
|
||||
@@ -642,7 +734,7 @@ class TestPublishedAllRagPipelineApi:
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
service.get_all_published_workflow.return_value = ([make_workflow(id="w1")], False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@@ -657,7 +749,8 @@ class TestPublishedAllRagPipelineApi:
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["items"][0]["id"] == "w1"
|
||||
assert result["items"][0]["graph"] == {"nodes": [], "edges": []}
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app: Flask):
|
||||
@@ -690,7 +783,7 @@ class TestRagPipelineByIdApi:
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow = make_workflow(id="w1", marked_name="test")
|
||||
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
@@ -711,7 +804,9 @@ class TestRagPipelineByIdApi:
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
|
||||
assert result == workflow
|
||||
assert result["id"] == "w1"
|
||||
assert result["marked_name"] == "test"
|
||||
assert result["hash"] == workflow.unique_hash
|
||||
|
||||
def test_patch_no_fields(self, app: Flask):
|
||||
api = RagPipelineByIdApi()
|
||||
|
||||
@@ -3,14 +3,18 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.variables import SecretVariable, StringVariable
|
||||
from graphon.variables.variables import RAGPipelineVariable
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
@@ -19,11 +23,67 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def _make_workflow(**overrides):
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-1",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
features_dict={"file_upload": {"enabled": False}},
|
||||
unique_hash="hash-1",
|
||||
version="1",
|
||||
marked_name="Release 1",
|
||||
marked_comment="Initial release",
|
||||
created_by_account=SimpleNamespace(id="user-1", name="Alice", email="alice@example.com"),
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0),
|
||||
updated_by_account=None,
|
||||
updated_at=datetime(2024, 1, 1, 12, 1, 0),
|
||||
tool_published=False,
|
||||
environment_variables=[
|
||||
{
|
||||
"id": "env-1",
|
||||
"name": "API_KEY",
|
||||
"value": "[__HIDDEN__]",
|
||||
"value_type": "secret",
|
||||
"description": "API key",
|
||||
}
|
||||
],
|
||||
conversation_variables=[
|
||||
{
|
||||
"id": "conv-1",
|
||||
"name": "topic",
|
||||
"value": "hello",
|
||||
"value_type": "string",
|
||||
"description": "Topic",
|
||||
}
|
||||
],
|
||||
rag_pipeline_variables=[
|
||||
{
|
||||
"variable": "query",
|
||||
"type": "text-input",
|
||||
"label": "Query",
|
||||
"belong_to_node_id": "shared",
|
||||
"max_length": 0,
|
||||
"required": False,
|
||||
"unit": "",
|
||||
"default_value": "",
|
||||
"options": [],
|
||||
"placeholder": "",
|
||||
"tooltips": "",
|
||||
"allowed_file_types": ["custom"],
|
||||
"allowed_file_extensions": [".pdf"],
|
||||
"allowed_file_upload_methods": ["local_file"],
|
||||
}
|
||||
],
|
||||
)
|
||||
for key, value in overrides.items():
|
||||
setattr(workflow, key, value)
|
||||
return workflow
|
||||
|
||||
|
||||
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
|
||||
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
|
||||
assert workflow_module._parse_file(cast(workflow_module.Workflow, workflow), files=[{"id": "f"}]) == []
|
||||
|
||||
|
||||
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -41,7 +101,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
|
||||
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
|
||||
result = workflow_module._parse_file(cast(workflow_module.Workflow, workflow), files=[{"id": "f"}])
|
||||
|
||||
assert result == file_list
|
||||
build_mock.assert_called_once()
|
||||
@@ -259,7 +319,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
|
||||
assert exc.value.description == "invalid workflow graph"
|
||||
|
||||
|
||||
def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_published_workflows_serializes_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.PublishedAllWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
@@ -278,7 +338,12 @@ def test_get_published_workflows_marshals_items_before_session_closes(app, monke
|
||||
def begin(self):
|
||||
return _SessionContext()
|
||||
|
||||
base_workflow = _make_workflow()
|
||||
|
||||
class _Workflow:
|
||||
def __getattr__(self, name):
|
||||
return getattr(base_workflow, name)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
assert session_state["open"] is True
|
||||
@@ -295,12 +360,6 @@ def test_get_published_workflows_marshals_items_before_session_closes(app, monke
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_marshal(items, fields):
|
||||
assert session_state["open"] is True
|
||||
return [{"id": item.id} for item in items]
|
||||
|
||||
monkeypatch.setattr(workflow_module, "marshal", _fake_marshal)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows",
|
||||
method="GET",
|
||||
@@ -308,12 +367,153 @@ def test_get_published_workflows_marshals_items_before_session_closes(app, monke
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
|
||||
|
||||
assert response == {
|
||||
"items": [{"id": "w1"}],
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"has_more": False,
|
||||
}
|
||||
assert response["items"][0]["id"] == "w1"
|
||||
assert response["page"] == 1
|
||||
assert response["limit"] == 10
|
||||
assert response["has_more"] is False
|
||||
|
||||
|
||||
def test_draft_workflow_get_serializes_response_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = _make_workflow()
|
||||
monkeypatch.setattr(
|
||||
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_kwargs: workflow)
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
response = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert response["id"] == "workflow-1"
|
||||
assert response["graph"] == {"nodes": [], "edges": []}
|
||||
assert response["features"] == {"file_upload": {"enabled": False}}
|
||||
assert response["hash"] == "hash-1"
|
||||
assert response["created_by"] == {"id": "user-1", "name": "Alice", "email": "alice@example.com"}
|
||||
assert response["updated_by"] is None
|
||||
assert response["created_at"] == int(datetime(2024, 1, 1, 12, 0, 0).timestamp())
|
||||
assert response["updated_at"] == int(datetime(2024, 1, 1, 12, 1, 0).timestamp())
|
||||
assert response["environment_variables"] == [
|
||||
{
|
||||
"id": "env-1",
|
||||
"name": "API_KEY",
|
||||
"value": "[__HIDDEN__]",
|
||||
"value_type": "secret",
|
||||
"description": "API key",
|
||||
}
|
||||
]
|
||||
assert response["conversation_variables"] == [
|
||||
{
|
||||
"id": "conv-1",
|
||||
"name": "topic",
|
||||
"value": "hello",
|
||||
"value_type": "string",
|
||||
"description": "Topic",
|
||||
}
|
||||
]
|
||||
assert response["rag_pipeline_variables"] == [
|
||||
{
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"type": "text-input",
|
||||
"belong_to_node_id": "shared",
|
||||
"max_length": 0,
|
||||
"required": False,
|
||||
"unit": "",
|
||||
"default_value": "",
|
||||
"options": [],
|
||||
"placeholder": "",
|
||||
"tooltips": "",
|
||||
"allowed_file_types": ["custom"],
|
||||
"allowed_file_extensions": [".pdf"],
|
||||
"allowed_file_upload_methods": ["local_file"],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_pipeline_variable_response_accepts_legacy_file_field_names() -> None:
|
||||
response = workflow_module.PipelineVariableResponse.model_validate(
|
||||
{
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"type": "single-file",
|
||||
"belong_to_node_id": "shared",
|
||||
"max_length": 0,
|
||||
"required": False,
|
||||
"unit": "",
|
||||
"default_value": "",
|
||||
"options": [],
|
||||
"placeholder": "",
|
||||
"tooltips": "",
|
||||
"allowed_file_types": [],
|
||||
"allow_file_extension": [".txt"],
|
||||
"allow_file_upload_methods": ["remote_url"],
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
assert response["allowed_file_extensions"] == [".txt"]
|
||||
assert response["allowed_file_upload_methods"] == ["remote_url"]
|
||||
|
||||
|
||||
def test_pipeline_variable_response_accepts_explicit_null_optional_fields() -> None:
|
||||
pipeline_variable = RAGPipelineVariable.model_validate(
|
||||
{
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"type": "text-input",
|
||||
"belong_to_node_id": "shared",
|
||||
"max_length": None,
|
||||
"unit": None,
|
||||
"default_value": None,
|
||||
"options": None,
|
||||
"placeholder": None,
|
||||
"tooltips": None,
|
||||
"allowed_file_types": None,
|
||||
"allowed_file_extensions": None,
|
||||
"allowed_file_upload_methods": None,
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
response = workflow_module.PipelineVariableResponse.model_validate(pipeline_variable).model_dump(mode="json")
|
||||
|
||||
assert response["max_length"] is None
|
||||
assert response["allowed_file_types"] is None
|
||||
assert response["allowed_file_extensions"] is None
|
||||
assert response["allowed_file_upload_methods"] is None
|
||||
|
||||
|
||||
def test_workflow_response_masks_secret_environment_variables() -> None:
|
||||
workflow = _make_workflow(
|
||||
environment_variables=[
|
||||
SecretVariable(id="env-secret", name="API_KEY", value="plain-token", selector=["env", "API_KEY"]),
|
||||
StringVariable(id="env-string", name="REGION", value="us-east-1", selector=["env", "REGION"]),
|
||||
]
|
||||
)
|
||||
|
||||
response = workflow_module.WorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert response["environment_variables"] == [
|
||||
{
|
||||
"id": "env-secret",
|
||||
"name": "API_KEY",
|
||||
"value": workflow_module.encrypter.full_mask_token(),
|
||||
"value_type": "secret",
|
||||
"description": "",
|
||||
},
|
||||
{
|
||||
"id": "env-string",
|
||||
"name": "REGION",
|
||||
"value": "us-east-1",
|
||||
"value_type": "string",
|
||||
"description": "",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_workflow_response_rejects_invalid_environment_variable_dict() -> None:
|
||||
workflow = _make_workflow(environment_variables=[{"value_type": "not-a-segment-type"}])
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
workflow_module.WorkflowResponse.model_validate(workflow, from_attributes=True)
|
||||
|
||||
|
||||
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -373,10 +573,33 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p
|
||||
"avatar": "avatar-file-id",
|
||||
"sid": "sid-1",
|
||||
}
|
||||
)
|
||||
),
|
||||
b"sid-malformed": json.dumps({"avatar": "avatar-file-id", "sid": "sid-malformed"}),
|
||||
b"sid-invalid-avatar": json.dumps(
|
||||
{
|
||||
"user_id": "u-2",
|
||||
"username": "Bob",
|
||||
"avatar": {"file_id": "avatar-file-id"},
|
||||
}
|
||||
),
|
||||
b"sid-invalid-user-id": json.dumps(
|
||||
{
|
||||
"user_id": 42,
|
||||
"username": "Carol",
|
||||
"avatar": "avatar-file-id",
|
||||
}
|
||||
),
|
||||
b"sid-invalid-username": json.dumps(
|
||||
{
|
||||
"user_id": "u-4",
|
||||
"username": ["Dave"],
|
||||
"avatar": "avatar-file-id",
|
||||
}
|
||||
),
|
||||
}
|
||||
]
|
||||
workflow_module.redis_client.pipeline.return_value = redis_pipeline
|
||||
redis_pipeline_factory = Mock(return_value=redis_pipeline)
|
||||
monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory)
|
||||
|
||||
api = workflow_module.WorkflowOnlineUsersApi()
|
||||
handler = _unwrap(api.post)
|
||||
@@ -397,13 +620,17 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p
|
||||
"user_id": "u-1",
|
||||
"username": "Alice",
|
||||
"avatar": signed_avatar_url,
|
||||
"sid": "sid-1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"user_id": "u-2",
|
||||
"username": "Bob",
|
||||
"avatar": None,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
workflow_module.redis_client.pipeline.assert_called_once_with(transaction=False)
|
||||
redis_pipeline_factory.assert_called_once_with(transaction=False)
|
||||
redis_pipeline.hgetall.assert_called_once_with(f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}")
|
||||
redis_pipeline.execute.assert_called_once_with()
|
||||
sign_avatar.assert_called_once_with("avatar-file-id")
|
||||
@@ -422,7 +649,8 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk
|
||||
first_pipeline.execute.return_value = [{} for _ in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE)]
|
||||
second_pipeline = Mock()
|
||||
second_pipeline.execute.return_value = [{}]
|
||||
workflow_module.redis_client.pipeline.side_effect = [first_pipeline, second_pipeline]
|
||||
redis_pipeline_factory = Mock(side_effect=[first_pipeline, second_pipeline])
|
||||
monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory)
|
||||
|
||||
api = workflow_module.WorkflowOnlineUsersApi()
|
||||
handler = _unwrap(api.post)
|
||||
@@ -435,7 +663,7 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk
|
||||
response = handler(api)
|
||||
|
||||
assert len(response["data"]) == len(app_ids)
|
||||
assert workflow_module.redis_client.pipeline.call_count == 2
|
||||
assert redis_pipeline_factory.call_count == 2
|
||||
assert first_pipeline.hgetall.call_count == workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
|
||||
assert second_pipeline.hgetall.call_count == 1
|
||||
|
||||
@@ -463,5 +691,6 @@ def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch:
|
||||
handler(api)
|
||||
|
||||
assert exc.value.code == 400
|
||||
assert exc.value.description is not None
|
||||
assert "Maximum" in exc.value.description
|
||||
accessible_app_ids.assert_not_called()
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _make_workflow(**overrides):
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-1",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
features_dict={"file_upload": {"enabled": False}},
|
||||
unique_hash="hash-1",
|
||||
version="1",
|
||||
marked_name="Release 1",
|
||||
marked_comment="Initial release",
|
||||
created_by_account=SimpleNamespace(id="user-1", name="Alice", email="alice@example.com"),
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0),
|
||||
updated_by_account=None,
|
||||
updated_at=datetime(2024, 1, 1, 12, 1, 0),
|
||||
tool_published=False,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
for key, value in overrides.items():
|
||||
setattr(workflow, key, value)
|
||||
return workflow
|
||||
|
||||
|
||||
def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = _make_workflow()
|
||||
monkeypatch.setattr(
|
||||
module, "RagPipelineService", lambda: SimpleNamespace(get_draft_workflow=lambda **_kwargs: workflow)
|
||||
)
|
||||
|
||||
api = module.DraftRagPipelineApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
|
||||
|
||||
assert response["id"] == "workflow-1"
|
||||
assert response["graph"] == {"nodes": [], "edges": []}
|
||||
assert response["features"] == {"file_upload": {"enabled": False}}
|
||||
assert response["hash"] == "hash-1"
|
||||
assert response["created_by"] == {"id": "user-1", "name": "Alice", "email": "alice@example.com"}
|
||||
assert response["updated_by"] is None
|
||||
assert response["created_at"] == int(datetime(2024, 1, 1, 12, 0, 0).timestamp())
|
||||
assert response["updated_at"] == int(datetime(2024, 1, 1, 12, 1, 0).timestamp())
|
||||
|
||||
|
||||
def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = module.PublishedAllRagPipelineApi()
|
||||
handler = _unwrap(api.get)
|
||||
session_state = {"open": False}
|
||||
|
||||
class _SessionContext:
|
||||
def __enter__(self):
|
||||
session_state["open"] = True
|
||||
return object()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
session_state["open"] = False
|
||||
return False
|
||||
|
||||
class _SessionMaker:
|
||||
def begin(self):
|
||||
return _SessionContext()
|
||||
|
||||
base_workflow = _make_workflow()
|
||||
|
||||
class _Workflow:
|
||||
def __getattr__(self, name: str):
|
||||
assert session_state["open"] is True
|
||||
return getattr(base_workflow, name)
|
||||
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"RagPipelineService",
|
||||
lambda: SimpleNamespace(get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False)),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/rag/pipelines/pipeline-1/workflows",
|
||||
method="GET",
|
||||
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
|
||||
):
|
||||
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
|
||||
|
||||
assert response["items"][0]["id"] == "workflow-1"
|
||||
assert response["page"] == 1
|
||||
assert response["limit"] == 10
|
||||
assert response["has_more"] is False
|
||||
|
||||
|
||||
def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = _make_workflow(marked_name="Updated release")
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1"))
|
||||
|
||||
class _SessionContext:
|
||||
def __enter__(self):
|
||||
return object()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class _SessionMaker:
|
||||
def begin(self):
|
||||
return _SessionContext()
|
||||
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"RagPipelineService",
|
||||
lambda: SimpleNamespace(update_workflow=lambda **_kwargs: workflow),
|
||||
)
|
||||
payload: dict[str, object] = {"marked_name": "Updated release"}
|
||||
|
||||
api = module.RagPipelineByIdApi()
|
||||
handler = _unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipelines/pipeline-1/workflows/workflow-1", method="PATCH", json=payload),
|
||||
patch.object(type(module.console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
):
|
||||
response = handler(
|
||||
api,
|
||||
pipeline=SimpleNamespace(id="pipeline-1", tenant_id="tenant-1"),
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
|
||||
assert response["id"] == "workflow-1"
|
||||
assert response["marked_name"] == "Updated release"
|
||||
assert response["hash"] == "hash-1"
|
||||
Reference in New Issue
Block a user