diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index e0a64ffe26..d5544ff473 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,13 +1,15 @@ import logging +from collections.abc import Mapping +from datetime import datetime from typing import Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Namespace, Resource, fields +from flask_restx import Resource, fields from graphon.enums import WorkflowExecutionStatus from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -33,9 +35,10 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount from libs import helper -from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel): register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +def _enum_value(value): + return getattr(value, "value", value) + + class WorkflowRunStatusField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - return obj.status.value + return _enum_value(obj.status) class WorkflowRunOutputsField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - if obj.status == WorkflowExecutionStatus.PAUSED: + status = _enum_value(obj.status) + if status == WorkflowExecutionStatus.PAUSED.value: return {} outputs = obj.outputs_dict return outputs or {} -workflow_run_fields = { - "id": fields.String, - "workflow_id": fields.String, - "status": WorkflowRunStatusField, - "inputs": fields.Raw, - "outputs": WorkflowRunOutputsField, - "error": fields.String, - "total_steps": fields.Integer, - "total_tokens": fields.Integer, - "created_at": TimestampField, - "finished_at": OptionalTimestampField, - "elapsed_time": fields.Float, -} +class WorkflowRunResponse(ResponseModel): + id: str + workflow_id: str + status: str + inputs: dict | list | str | int | float | bool | None = None + outputs: dict = Field(default_factory=dict) + error: str | None = None + total_steps: int | None = None + total_tokens: int | None = None + created_at: int | None = None + finished_at: int | None = None + elapsed_time: float | int | None = None + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_workflow_run_model(api_or_ns: Namespace): - """Build the workflow run model for the API or Namespace.""" - return api_or_ns.model("WorkflowRun", workflow_run_fields) +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | int | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", "triggered_from", mode="before") + @classmethod + def _normalize_enum(cls, value): + return _enum_value(value) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: dict | list | str | int | float | bool | None = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_from", "created_by_role", mode="before") + @classmethod + def _normalize_enum(cls, value): + return _enum_value(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +register_schema_models( + service_api_ns, + WorkflowRunResponse, + WorkflowRunForLogResponse, + WorkflowAppLogPartialResponse, + WorkflowAppLogPaginationResponse, +) + + +def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict: + status = _enum_value(workflow_run.status) + raw_outputs = workflow_run.outputs_dict + if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None: + outputs: dict = {} + elif isinstance(raw_outputs, dict): + outputs = raw_outputs + elif isinstance(raw_outputs, Mapping): + outputs = dict(raw_outputs) + else: + outputs = {} + return WorkflowRunResponse.model_validate( + { + "id": workflow_run.id, + "workflow_id": workflow_run.workflow_id, + "status": status, + "inputs": workflow_run.inputs, + "outputs": outputs, + "error": workflow_run.error, + "total_steps": workflow_run.total_steps, + "total_tokens": workflow_run.total_tokens, + "created_at": workflow_run.created_at, + "finished_at": workflow_run.finished_at, + "elapsed_time": workflow_run.elapsed_time, + } + ).model_dump(mode="json") + + +def _serialize_workflow_log_pagination(pagination) -> dict: + return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json") @service_api_ns.route("/workflows/run/") @@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) + @service_api_ns.response( + 200, + "Workflow run details retrieved successfully", + service_api_ns.models[WorkflowRunResponse.__name__], + ) def get(self, app_model: App, workflow_run_id: str): """Get a workflow task running detail. @@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource): ) if not workflow_run: raise NotFound("Workflow run not found.") - return workflow_run + return _serialize_workflow_run(workflow_run) @service_api_ns.route("/workflows/run") @@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) + @service_api_ns.response( + 200, + "Logs retrieved successfully", + service_api_ns.models[WorkflowAppLogPaginationResponse.__name__], + ) def get(self, app_model: App): """Get workflow app logs. @@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource): created_by_account=args.created_by_account, ) - return workflow_app_log_pagination + return _serialize_workflow_log_pagination(workflow_app_log_pagination) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index cfa21bf2dd..74a3c75839 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -15,6 +15,7 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch @@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService +def _make_mock_workflow_run(run_id: str = "run-1"): + run = Mock() + run.id = run_id + run.workflow_id = "wf-1" + run.status = WorkflowExecutionStatus.SUCCEEDED + run.inputs = {"input": "value"} + run.outputs_dict = {"output": "value"} + run.error = None + run.total_steps = 1 + run.total_tokens = 10 + run.created_at = datetime(2026, 1, 1, tzinfo=UTC) + run.finished_at = datetime(2026, 1, 1, tzinfo=UTC) + run.elapsed_time = 0.1 + return run + + class TestWorkflowRunPayload: """Test suite for WorkflowRunPayload Pydantic model.""" @@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi: handler(api, app_model=app_model, workflow_run_id="run") def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - run = SimpleNamespace(id="run") + run = _make_mock_workflow_run(run_id="run") repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run) workflow_module = sys.modules["controllers.service_api.app.workflow"] monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) @@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi: handler = _unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1") - assert handler(api, app_model=app_model, workflow_run_id="run") == run + result = handler(api, app_model=app_model, workflow_run_id="run") + assert result["id"] == "run" + assert result["workflow_id"] == "wf-1" + assert result["status"] == "succeeded" class TestWorkflowRunApi: @@ -490,7 +510,7 @@ class TestWorkflowAppLogApi: monkeypatch.setattr( WorkflowAppService, "get_paginate_workflow_app_logs", - lambda *_args, **_kwargs: {"items": [], "total": 0}, + lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}, ) api = WorkflowAppLogApi() @@ -500,7 +520,7 @@ class TestWorkflowAppLogApi: with app.test_request_context("/workflows/logs", method="GET"): response = handler(api, app_model=app_model) - assert response == {"items": [], "total": 0} + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} # ============================================================================= @@ -527,9 +547,8 @@ def mock_workflow_app(): class TestWorkflowRunDetailApiGet: """Test suite for WorkflowRunDetailApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``) - and ``@service_api_ns.marshal_with``. We call the unwrapped method - directly; ``marshal_with`` is a no-op when calling directly. + ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``), + and we call the unwrapped method directly in tests. """ @patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory") @@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet: mock_workflow_app, ): """Test successful workflow run detail retrieval.""" - mock_run = Mock() - mock_run.id = "run-1" - mock_run.status = "succeeded" + mock_run = _make_mock_workflow_run(run_id="run-1") mock_repo = Mock() mock_repo.get_workflow_run_by_id.return_value = mock_run mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo @@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet: api = WorkflowRunDetailApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) - assert result == mock_run + assert result["id"] == mock_run.id + assert result["status"] == "succeeded" @patch("controllers.service_api.app.workflow.db") def test_get_workflow_run_wrong_app_mode(self, mock_db, app): @@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost: class TestWorkflowAppLogApiGet: """Test suite for WorkflowAppLogApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` and - ``@service_api_ns.marshal_with``. + ``get`` is wrapped by ``@validate_app_token``. """ @patch("controllers.service_api.app.workflow.WorkflowAppService") @@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet: ): """Test successful workflow log retrieval.""" mock_pagination = Mock() + mock_pagination.page = 1 + mock_pagination.limit = 20 + mock_pagination.total = 0 + mock_pagination.has_more = False mock_pagination.data = [] mock_svc_instance = Mock() mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination @@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet: api = WorkflowAppLogApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app) - assert result == mock_pagination + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}