mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
Feat/add status filter to workflow runs (#26850)
Co-authored-by: Jacky Su <jacky_su@trendmicro.com>
This commit is contained in:
@@ -8,15 +8,81 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_count_fields,
|
||||
workflow_run_detail_fields,
|
||||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser
|
||||
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
|
||||
|
||||
def _parse_workflow_run_list_args():
|
||||
"""
|
||||
Parse common arguments for workflow run list endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_workflow_run_count_args():
|
||||
"""
|
||||
Parse common arguments for workflow run count endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing status, time_range, and triggered_from filters
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@@ -24,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@api.doc(description="Get advanced chat workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -34,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
args = _parse_workflow_run_list_args()
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@api.doc("get_advanced_chat_workflow_runs_count")
|
||||
@api.doc(description="Get advanced chat workflow runs count statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -51,6 +170,8 @@ class WorkflowRunListApi(Resource):
|
||||
@api.doc(description="Get workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -61,13 +182,64 @@ class WorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
args = _parse_workflow_run_list_args()
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||
class WorkflowRunCountApi(Resource):
|
||||
@api.doc("get_workflow_runs_count")
|
||||
@api.doc(description="Get workflow runs count statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -64,6 +64,15 @@ workflow_run_pagination_fields = {
|
||||
"data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"),
|
||||
}
|
||||
|
||||
workflow_run_count_fields = {
|
||||
"total": fields.Integer,
|
||||
"running": fields.Integer,
|
||||
"succeeded": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"stopped": fields.Integer,
|
||||
"partial_succeeded": fields.Integer(attribute="partial-succeeded"),
|
||||
}
|
||||
|
||||
workflow_run_detail_fields = {
|
||||
"id": fields.String,
|
||||
"version": fields.String,
|
||||
|
||||
32
api/libs/custom_inputs.py
Normal file
32
api/libs/custom_inputs.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Custom input types for Flask-RESTX request parsing."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def time_duration(value: str) -> str:
|
||||
"""
|
||||
Validate and return time duration string.
|
||||
|
||||
Accepts formats: <number>d (days), <number>h (hours), <number>m (minutes), <number>s (seconds)
|
||||
Examples: 7d, 4h, 30m, 30s
|
||||
|
||||
Args:
|
||||
value: The time duration string
|
||||
|
||||
Returns:
|
||||
The validated time duration string
|
||||
|
||||
Raises:
|
||||
ValueError: If the format is invalid
|
||||
"""
|
||||
if not value:
|
||||
raise ValueError("Time duration cannot be empty")
|
||||
|
||||
pattern = r"^(\d+)([dhms])$"
|
||||
if not re.match(pattern, value.lower()):
|
||||
raise ValueError(
|
||||
"Invalid time duration format. Use: <number>d (days), <number>h (hours), "
|
||||
"<number>m (minutes), or <number>s (seconds). Examples: 7d, 4h, 30m, 30s"
|
||||
)
|
||||
|
||||
return value.lower()
|
||||
67
api/libs/time_parser.py
Normal file
67
api/libs/time_parser.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Time duration parser utility."""
|
||||
|
||||
import re
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
|
||||
def parse_time_duration(duration_str: str) -> timedelta | None:
|
||||
"""
|
||||
Parse time duration string to timedelta.
|
||||
|
||||
Supported formats:
|
||||
- 7d: 7 days
|
||||
- 4h: 4 hours
|
||||
- 30m: 30 minutes
|
||||
- 30s: 30 seconds
|
||||
|
||||
Args:
|
||||
duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")
|
||||
|
||||
Returns:
|
||||
timedelta object or None if invalid format
|
||||
"""
|
||||
if not duration_str:
|
||||
return None
|
||||
|
||||
# Pattern: number followed by unit (d, h, m, s)
|
||||
pattern = r"^(\d+)([dhms])$"
|
||||
match = re.match(pattern, duration_str.lower())
|
||||
|
||||
if not match:
|
||||
return None
|
||||
|
||||
value = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
if unit == "d":
|
||||
return timedelta(days=value)
|
||||
elif unit == "h":
|
||||
return timedelta(hours=value)
|
||||
elif unit == "m":
|
||||
return timedelta(minutes=value)
|
||||
elif unit == "s":
|
||||
return timedelta(seconds=value)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_time_threshold(duration_str: str | None) -> datetime | None:
|
||||
"""
|
||||
Get datetime threshold from duration string.
|
||||
|
||||
Calculates the datetime that is duration_str ago from now.
|
||||
|
||||
Args:
|
||||
duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")
|
||||
|
||||
Returns:
|
||||
datetime object representing the threshold time, or None if no duration
|
||||
"""
|
||||
if not duration_str:
|
||||
return None
|
||||
|
||||
duration = parse_time_duration(duration_str)
|
||||
if duration is None:
|
||||
return None
|
||||
|
||||
return datetime.now(UTC) - duration
|
||||
@@ -59,6 +59,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
triggered_from: str,
|
||||
limit: int = 20,
|
||||
last_id: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get paginated workflow runs with filtering.
|
||||
@@ -73,6 +74,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
|
||||
limit: Maximum number of records to return (default: 20)
|
||||
last_id: Cursor for pagination - ID of the last record from previous page
|
||||
status: Optional filter by status (e.g., "running", "succeeded", "failed")
|
||||
|
||||
Returns:
|
||||
InfiniteScrollPagination object containing:
|
||||
@@ -107,6 +109,43 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_workflow_runs_count(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
status: str | None = None,
|
||||
time_range: str | None = None,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Get workflow runs count statistics.
|
||||
|
||||
Retrieves total count and count by status for workflow runs
|
||||
matching the specified filters.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
app_id: Application identifier
|
||||
triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
|
||||
status: Optional filter by specific status
|
||||
time_range: Optional time range filter (e.g., "7d", "4h", "30m", "30s")
|
||||
Filters records based on created_at field
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- total: Total count of all workflow runs (or filtered by status)
|
||||
- running: Count of workflow runs with status "running"
|
||||
- succeeded: Count of workflow runs with status "succeeded"
|
||||
- failed: Count of workflow runs with status "failed"
|
||||
- stopped: Count of workflow runs with status "stopped"
|
||||
- partial_succeeded: Count of workflow runs with status "partial-succeeded"
|
||||
|
||||
Note: If a status is provided, 'total' will be the count for that status,
|
||||
and the specific status count will also be set to this value, with all
|
||||
other status counts being 0.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_expired_runs_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@@ -24,11 +24,12 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.time_parser import get_time_threshold
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
|
||||
@@ -63,6 +64,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
triggered_from: str,
|
||||
limit: int = 20,
|
||||
last_id: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get paginated workflow runs with filtering.
|
||||
@@ -79,6 +81,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
WorkflowRun.triggered_from == triggered_from,
|
||||
)
|
||||
|
||||
# Add optional status filter
|
||||
if status:
|
||||
base_stmt = base_stmt.where(WorkflowRun.status == status)
|
||||
|
||||
if last_id:
|
||||
# Get the last workflow run for cursor-based pagination
|
||||
last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
|
||||
@@ -120,6 +126,73 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
def get_workflow_runs_count(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
status: str | None = None,
|
||||
time_range: str | None = None,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Get workflow runs count statistics grouped by status.
|
||||
"""
|
||||
_initial_status_counts = {
|
||||
"running": 0,
|
||||
"succeeded": 0,
|
||||
"failed": 0,
|
||||
"stopped": 0,
|
||||
"partial-succeeded": 0,
|
||||
}
|
||||
|
||||
with self._session_maker() as session:
|
||||
# Build base where conditions
|
||||
base_conditions = [
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.app_id == app_id,
|
||||
WorkflowRun.triggered_from == triggered_from,
|
||||
]
|
||||
|
||||
# Add time range filter if provided
|
||||
if time_range:
|
||||
time_threshold = get_time_threshold(time_range)
|
||||
if time_threshold:
|
||||
base_conditions.append(WorkflowRun.created_at >= time_threshold)
|
||||
|
||||
# If status filter is provided, return simple count
|
||||
if status:
|
||||
count_stmt = select(func.count(WorkflowRun.id)).where(*base_conditions, WorkflowRun.status == status)
|
||||
total = session.scalar(count_stmt) or 0
|
||||
|
||||
result = {"total": total} | _initial_status_counts
|
||||
|
||||
# Set the count for the filtered status
|
||||
if status in result:
|
||||
result[status] = total
|
||||
|
||||
return result
|
||||
|
||||
# No status filter - get counts grouped by status
|
||||
base_stmt = (
|
||||
select(WorkflowRun.status, func.count(WorkflowRun.id).label("count"))
|
||||
.where(*base_conditions)
|
||||
.group_by(WorkflowRun.status)
|
||||
)
|
||||
|
||||
# Execute query
|
||||
results = session.execute(base_stmt).all()
|
||||
|
||||
# Build response dictionary
|
||||
status_counts = _initial_status_counts.copy()
|
||||
|
||||
total = 0
|
||||
for status_val, count in results:
|
||||
total += count
|
||||
if status_val in status_counts:
|
||||
status_counts[status_val] = count
|
||||
|
||||
return {"total": total} | status_counts
|
||||
|
||||
def get_expired_runs_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@@ -26,13 +26,15 @@ class WorkflowRunService:
|
||||
)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
|
||||
def get_paginate_advanced_chat_workflow_runs(
|
||||
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
Only return triggered_from == advanced_chat
|
||||
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
:param triggered_from: workflow run triggered from (default: DEBUGGING for preview runs)
|
||||
"""
|
||||
|
||||
class WorkflowWithMessage:
|
||||
@@ -45,7 +47,7 @@ class WorkflowRunService:
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._workflow_run, item)
|
||||
|
||||
pagination = self.get_paginate_workflow_runs(app_model, args)
|
||||
pagination = self.get_paginate_workflow_runs(app_model, args, triggered_from)
|
||||
|
||||
with_message_workflow_runs = []
|
||||
for workflow_run in pagination.data:
|
||||
@@ -60,23 +62,27 @@ class WorkflowRunService:
|
||||
pagination.data = with_message_workflow_runs
|
||||
return pagination
|
||||
|
||||
def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
|
||||
def get_paginate_workflow_runs(
|
||||
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get debug workflow run list
|
||||
Only return triggered_from == debugging
|
||||
Get workflow run list
|
||||
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
:param triggered_from: workflow run triggered from (default: DEBUGGING)
|
||||
"""
|
||||
limit = int(args.get("limit", 20))
|
||||
last_id = args.get("last_id")
|
||||
status = args.get("status")
|
||||
|
||||
return self._workflow_run_repo.get_paginated_workflow_runs(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
triggered_from=triggered_from,
|
||||
limit=limit,
|
||||
last_id=last_id,
|
||||
status=status,
|
||||
)
|
||||
|
||||
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
|
||||
@@ -92,6 +98,30 @@ class WorkflowRunService:
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
def get_workflow_runs_count(
|
||||
self,
|
||||
app_model: App,
|
||||
status: str | None = None,
|
||||
time_range: str | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
|
||||
:param app_model: app model
|
||||
:param status: optional status filter
|
||||
:param time_range: optional time range filter (e.g., "7d", "4h", "30m", "30s")
|
||||
:param triggered_from: workflow run triggered from (default: DEBUGGING)
|
||||
:return: dict with total and status counts
|
||||
"""
|
||||
return self._workflow_run_repo.get_workflow_runs_count(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=triggered_from,
|
||||
status=status,
|
||||
time_range=time_range,
|
||||
)
|
||||
|
||||
def get_workflow_run_node_executions(
|
||||
self,
|
||||
app_model: App,
|
||||
|
||||
68
api/tests/unit_tests/libs/test_custom_inputs.py
Normal file
68
api/tests/unit_tests/libs/test_custom_inputs.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Unit tests for custom input types."""
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.custom_inputs import time_duration
|
||||
|
||||
|
||||
class TestTimeDuration:
|
||||
"""Test time_duration input validator."""
|
||||
|
||||
def test_valid_days(self):
|
||||
"""Test valid days format."""
|
||||
result = time_duration("7d")
|
||||
assert result == "7d"
|
||||
|
||||
def test_valid_hours(self):
|
||||
"""Test valid hours format."""
|
||||
result = time_duration("4h")
|
||||
assert result == "4h"
|
||||
|
||||
def test_valid_minutes(self):
|
||||
"""Test valid minutes format."""
|
||||
result = time_duration("30m")
|
||||
assert result == "30m"
|
||||
|
||||
def test_valid_seconds(self):
|
||||
"""Test valid seconds format."""
|
||||
result = time_duration("30s")
|
||||
assert result == "30s"
|
||||
|
||||
def test_uppercase_conversion(self):
|
||||
"""Test uppercase units are converted to lowercase."""
|
||||
result = time_duration("7D")
|
||||
assert result == "7d"
|
||||
|
||||
result = time_duration("4H")
|
||||
assert result == "4h"
|
||||
|
||||
def test_invalid_format_no_unit(self):
|
||||
"""Test invalid format without unit."""
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("7")
|
||||
|
||||
def test_invalid_format_wrong_unit(self):
|
||||
"""Test invalid format with wrong unit."""
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("7days")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("7x")
|
||||
|
||||
def test_invalid_format_no_number(self):
|
||||
"""Test invalid format without number."""
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("d")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("abc")
|
||||
|
||||
def test_empty_string(self):
|
||||
"""Test empty string."""
|
||||
with pytest.raises(ValueError, match="Time duration cannot be empty"):
|
||||
time_duration("")
|
||||
|
||||
def test_none(self):
|
||||
"""Test None value."""
|
||||
with pytest.raises(ValueError, match="Time duration cannot be empty"):
|
||||
time_duration(None)
|
||||
91
api/tests/unit_tests/libs/test_time_parser.py
Normal file
91
api/tests/unit_tests/libs/test_time_parser.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Unit tests for time parser utility."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from libs.time_parser import get_time_threshold, parse_time_duration
|
||||
|
||||
|
||||
class TestParseTimeDuration:
|
||||
"""Test parse_time_duration function."""
|
||||
|
||||
def test_parse_days(self):
|
||||
"""Test parsing days."""
|
||||
result = parse_time_duration("7d")
|
||||
assert result == timedelta(days=7)
|
||||
|
||||
def test_parse_hours(self):
|
||||
"""Test parsing hours."""
|
||||
result = parse_time_duration("4h")
|
||||
assert result == timedelta(hours=4)
|
||||
|
||||
def test_parse_minutes(self):
|
||||
"""Test parsing minutes."""
|
||||
result = parse_time_duration("30m")
|
||||
assert result == timedelta(minutes=30)
|
||||
|
||||
def test_parse_seconds(self):
|
||||
"""Test parsing seconds."""
|
||||
result = parse_time_duration("30s")
|
||||
assert result == timedelta(seconds=30)
|
||||
|
||||
def test_parse_uppercase(self):
|
||||
"""Test parsing uppercase units."""
|
||||
result = parse_time_duration("7D")
|
||||
assert result == timedelta(days=7)
|
||||
|
||||
def test_parse_invalid_format(self):
|
||||
"""Test parsing invalid format."""
|
||||
result = parse_time_duration("7days")
|
||||
assert result is None
|
||||
|
||||
result = parse_time_duration("abc")
|
||||
assert result is None
|
||||
|
||||
result = parse_time_duration("7")
|
||||
assert result is None
|
||||
|
||||
def test_parse_empty_string(self):
|
||||
"""Test parsing empty string."""
|
||||
result = parse_time_duration("")
|
||||
assert result is None
|
||||
|
||||
def test_parse_none(self):
|
||||
"""Test parsing None."""
|
||||
result = parse_time_duration(None)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetTimeThreshold:
|
||||
"""Test get_time_threshold function."""
|
||||
|
||||
def test_get_threshold_days(self):
|
||||
"""Test getting threshold for days."""
|
||||
before = datetime.now(UTC)
|
||||
result = get_time_threshold("7d")
|
||||
after = datetime.now(UTC)
|
||||
|
||||
assert result is not None
|
||||
# Result should be approximately 7 days ago
|
||||
expected = before - timedelta(days=7)
|
||||
# Allow 1 second tolerance for test execution time
|
||||
assert abs((result - expected).total_seconds()) < 1
|
||||
|
||||
def test_get_threshold_hours(self):
|
||||
"""Test getting threshold for hours."""
|
||||
before = datetime.now(UTC)
|
||||
result = get_time_threshold("4h")
|
||||
after = datetime.now(UTC)
|
||||
|
||||
assert result is not None
|
||||
expected = before - timedelta(hours=4)
|
||||
assert abs((result - expected).total_seconds()) < 1
|
||||
|
||||
def test_get_threshold_invalid(self):
|
||||
"""Test getting threshold with invalid duration."""
|
||||
result = get_time_threshold("invalid")
|
||||
assert result is None
|
||||
|
||||
def test_get_threshold_none(self):
|
||||
"""Test getting threshold with None."""
|
||||
result = get_time_threshold(None)
|
||||
assert result is None
|
||||
@@ -0,0 +1,251 @@
|
||||
"""Unit tests for workflow run repository with status filter."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from models import WorkflowRun, WorkflowRunTriggeredFrom
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
|
||||
|
||||
|
||||
class TestDifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""Test workflow run repository with status filtering."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(self):
|
||||
"""Create a mock session maker."""
|
||||
return MagicMock(spec=sessionmaker)
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_session_maker):
|
||||
"""Create repository instance with mock session."""
|
||||
return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
|
||||
|
||||
def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker):
|
||||
"""Test getting paginated workflow runs without status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)]
|
||||
mock_session.scalars.return_value.all.return_value = mock_runs
|
||||
|
||||
# Act
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=None,
|
||||
status=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 3
|
||||
assert result.limit == 20
|
||||
assert result.has_more is False
|
||||
|
||||
def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker):
|
||||
"""Test getting paginated workflow runs with status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)]
|
||||
mock_session.scalars.return_value.all.return_value = mock_runs
|
||||
|
||||
# Act
|
||||
result = repository.get_paginated_workflow_runs(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=20,
|
||||
last_id=None,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 2
|
||||
assert all(run.status == "succeeded" for run in result.data)
|
||||
|
||||
def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count without status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the GROUP BY query results
|
||||
mock_results = [
|
||||
("succeeded", 5),
|
||||
("failed", 2),
|
||||
("running", 1),
|
||||
]
|
||||
mock_session.execute.return_value.all.return_value = mock_results
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 8
|
||||
assert result["succeeded"] == 5
|
||||
assert result["failed"] == 2
|
||||
assert result["running"] == 1
|
||||
assert result["stopped"] == 0
|
||||
assert result["partial-succeeded"] == 0
|
||||
|
||||
def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count with status filter."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the count query for succeeded status
|
||||
mock_session.scalar.return_value = 5
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 5
|
||||
assert result["succeeded"] == 5
|
||||
assert result["running"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["stopped"] == 0
|
||||
assert result["partial-succeeded"] == 0
|
||||
|
||||
def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker):
|
||||
"""Test that invalid status is still counted in total but not in any specific status."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock count query returning 0 for invalid status
|
||||
mock_session.scalar.return_value = 0
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="invalid_status",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 0
|
||||
assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"])
|
||||
|
||||
def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count with time range filter verifies SQL query construction."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the GROUP BY query results
|
||||
mock_results = [
|
||||
("succeeded", 3),
|
||||
("running", 2),
|
||||
]
|
||||
mock_session.execute.return_value.all.return_value = mock_results
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status=None,
|
||||
time_range="1d",
|
||||
)
|
||||
|
||||
# Assert results
|
||||
assert result["total"] == 5
|
||||
assert result["succeeded"] == 3
|
||||
assert result["running"] == 2
|
||||
assert result["failed"] == 0
|
||||
|
||||
# Verify that execute was called (which means GROUP BY query was used)
|
||||
assert mock_session.execute.called, "execute should have been called for GROUP BY query"
|
||||
|
||||
# Verify SQL query includes time filter by checking the statement
|
||||
call_args = mock_session.execute.call_args
|
||||
assert call_args is not None, "execute should have been called with a statement"
|
||||
|
||||
# The first argument should be the SQL statement
|
||||
stmt = call_args[0][0]
|
||||
# Convert to string to inspect the query
|
||||
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
# Verify the query includes created_at filter
|
||||
# The query should have a WHERE clause with created_at comparison
|
||||
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
|
||||
"Query should include created_at filter for time range"
|
||||
)
|
||||
|
||||
def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker):
|
||||
"""Test getting workflow runs count with both status and time range filters verifies SQL query."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the count query for running status within time range
|
||||
mock_session.scalar.return_value = 2
|
||||
|
||||
# Act
|
||||
result = repository.get_workflow_runs_count(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
status="running",
|
||||
time_range="1d",
|
||||
)
|
||||
|
||||
# Assert results
|
||||
assert result["total"] == 2
|
||||
assert result["running"] == 2
|
||||
assert result["succeeded"] == 0
|
||||
assert result["failed"] == 0
|
||||
|
||||
# Verify that scalar was called (which means COUNT query was used)
|
||||
assert mock_session.scalar.called, "scalar should have been called for count query"
|
||||
|
||||
# Verify SQL query includes both status and time filter
|
||||
call_args = mock_session.scalar.call_args
|
||||
assert call_args is not None, "scalar should have been called with a statement"
|
||||
|
||||
# The first argument should be the SQL statement
|
||||
stmt = call_args[0][0]
|
||||
# Convert to string to inspect the query
|
||||
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
# Verify the query includes both filters
|
||||
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
|
||||
"Query should include created_at filter for time range"
|
||||
)
|
||||
assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), (
|
||||
"Query should include status filter"
|
||||
)
|
||||
Reference in New Issue
Block a user