mirror of
https://github.com/langgenius/dify.git
synced 2026-05-26 04:00:39 -04:00
refactor(api): migrate console.app.workflow to BaseModel (#36216)
Co-authored-by: WH-2099 <wh2099@pm.me> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import TypedDict, Unpack
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -35,9 +37,54 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
DEFAULT_WORKFLOW_TENANT_ID = "00000000-0000-0000-0000-000000000001"
|
||||
DEFAULT_WORKFLOW_APP_ID = "00000000-0000-0000-0000-000000000002"
|
||||
DEFAULT_WORKFLOW_CREATED_BY = "00000000-0000-0000-0000-000000000003"
|
||||
type WorkflowVariablePayload = dict[str, object]
|
||||
|
||||
|
||||
class WorkflowFactoryPayload(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
type: str
|
||||
version: str
|
||||
marked_name: str
|
||||
marked_comment: str
|
||||
graph: str
|
||||
features: str
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_by: str | None
|
||||
updated_at: datetime
|
||||
environment_variables: list[WorkflowVariablePayload]
|
||||
conversation_variables: list[WorkflowVariablePayload]
|
||||
rag_pipeline_variables: list[WorkflowVariablePayload]
|
||||
|
||||
|
||||
class WorkflowFactoryOverrides(TypedDict, total=False):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
type: str
|
||||
version: str
|
||||
marked_name: str
|
||||
marked_comment: str
|
||||
graph: str
|
||||
features: str
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_by: str | None
|
||||
updated_at: datetime
|
||||
environment_variables: list[WorkflowVariablePayload]
|
||||
conversation_variables: list[WorkflowVariablePayload]
|
||||
rag_pipeline_variables: list[WorkflowVariablePayload]
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
@@ -74,17 +121,52 @@ def make_node_execution(**overrides):
|
||||
return SimpleNamespace(**payload)
|
||||
|
||||
|
||||
def default_workflow_payload() -> WorkflowFactoryPayload:
|
||||
return {
|
||||
"id": "workflow-1",
|
||||
"tenant_id": DEFAULT_WORKFLOW_TENANT_ID,
|
||||
"app_id": DEFAULT_WORKFLOW_APP_ID,
|
||||
"type": "workflow",
|
||||
"version": "1",
|
||||
"marked_name": "Release 1",
|
||||
"marked_comment": "Initial release",
|
||||
"graph": json.dumps({"nodes": [], "edges": []}),
|
||||
"features": json.dumps({"file_upload": {"enabled": False}}),
|
||||
"created_by": DEFAULT_WORKFLOW_CREATED_BY,
|
||||
"created_at": datetime(2024, 1, 1, 12, 0, 0),
|
||||
"updated_by": None,
|
||||
"updated_at": datetime(2024, 1, 1, 12, 1, 0),
|
||||
"environment_variables": [],
|
||||
"conversation_variables": [],
|
||||
"rag_pipeline_variables": [],
|
||||
}
|
||||
|
||||
|
||||
def make_workflow(**overrides: Unpack[WorkflowFactoryOverrides]) -> Workflow:
|
||||
payload = default_workflow_payload()
|
||||
payload.update(overrides)
|
||||
return Workflow(**payload)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_author(db_session_with_containers: Session) -> Account:
|
||||
account = Account(name="Alice", email=f"alice-{uuid4()}@example.com")
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
return account
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_draft_success(self, app: Flask):
|
||||
def test_get_draft_success(self, app: Flask, workflow_author: Account):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow = make_workflow(created_by=workflow_author.id)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
@@ -97,7 +179,17 @@ class TestDraftWorkflowApi:
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
assert result["id"] == "workflow-1"
|
||||
assert result["graph"] == {"nodes": [], "edges": []}
|
||||
assert result["features"] == {"file_upload": {"enabled": False}}
|
||||
assert result["hash"] == workflow.unique_hash
|
||||
assert result["created_by"] == {
|
||||
"id": workflow_author.id,
|
||||
"name": workflow_author.name,
|
||||
"email": workflow_author.email,
|
||||
}
|
||||
assert result["updated_by"] is None
|
||||
|
||||
def test_get_draft_not_exist(self, app: Flask):
|
||||
api = DraftRagPipelineApi()
|
||||
@@ -642,7 +734,7 @@ class TestPublishedAllRagPipelineApi:
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
service.get_all_published_workflow.return_value = ([make_workflow(id="w1")], False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@@ -657,7 +749,8 @@ class TestPublishedAllRagPipelineApi:
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["items"][0]["id"] == "w1"
|
||||
assert result["items"][0]["graph"] == {"nodes": [], "edges": []}
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app: Flask):
|
||||
@@ -690,7 +783,7 @@ class TestRagPipelineByIdApi:
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow = make_workflow(id="w1", marked_name="test")
|
||||
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
@@ -711,7 +804,9 @@ class TestRagPipelineByIdApi:
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
|
||||
assert result == workflow
|
||||
assert result["id"] == "w1"
|
||||
assert result["marked_name"] == "test"
|
||||
assert result["hash"] == workflow.unique_hash
|
||||
|
||||
def test_patch_no_fields(self, app: Flask):
|
||||
api = RagPipelineByIdApi()
|
||||
|
||||
Reference in New Issue
Block a user