Feat/add status filter to workflow runs (#26850)

Co-authored-by: Jacky Su <jacky_su@trendmicro.com>
This commit is contained in:
Jacky Su
2025-10-18 12:15:29 +08:00
committed by GitHub
parent 1a37989769
commit ac79691d69
10 changed files with 851 additions and 19 deletions

View File

@@ -8,15 +8,81 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_run_fields import (
advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields,
workflow_run_detail_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource):
@@ -24,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@api.doc(description="Get advanced chat workflow run list")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@setup_required
@login_required
@@ -34,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
"""
Get advanced chat app workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
args = _parse_workflow_run_list_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@api.doc("get_advanced_chat_workflow_runs_count")
@api.doc(description="Get advanced chat workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get advanced chat workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result
@@ -51,6 +170,8 @@ class WorkflowRunListApi(Resource):
@api.doc(description="Get workflow run list")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@setup_required
@login_required
@@ -61,13 +182,64 @@ class WorkflowRunListApi(Resource):
"""
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
args = _parse_workflow_run_list_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource):
@api.doc("get_workflow_runs_count")
@api.doc(description="Get workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result

View File

@@ -64,6 +64,15 @@ workflow_run_pagination_fields = {
"data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"),
}
workflow_run_count_fields = {
"total": fields.Integer,
"running": fields.Integer,
"succeeded": fields.Integer,
"failed": fields.Integer,
"stopped": fields.Integer,
"partial_succeeded": fields.Integer(attribute="partial-succeeded"),
}
workflow_run_detail_fields = {
"id": fields.String,
"version": fields.String,

32
api/libs/custom_inputs.py Normal file
View File

@@ -0,0 +1,32 @@
"""Custom input types for Flask-RESTX request parsing."""
import re
def time_duration(value: str) -> str:
"""
Validate and return time duration string.
Accepts formats: <number>d (days), <number>h (hours), <number>m (minutes), <number>s (seconds)
Examples: 7d, 4h, 30m, 30s
Args:
value: The time duration string
Returns:
The validated time duration string
Raises:
ValueError: If the format is invalid
"""
if not value:
raise ValueError("Time duration cannot be empty")
pattern = r"^(\d+)([dhms])$"
if not re.match(pattern, value.lower()):
raise ValueError(
"Invalid time duration format. Use: <number>d (days), <number>h (hours), "
"<number>m (minutes), or <number>s (seconds). Examples: 7d, 4h, 30m, 30s"
)
return value.lower()

67
api/libs/time_parser.py Normal file
View File

@@ -0,0 +1,67 @@
"""Time duration parser utility."""
import re
from datetime import UTC, datetime, timedelta
def parse_time_duration(duration_str: str) -> timedelta | None:
"""
Parse time duration string to timedelta.
Supported formats:
- 7d: 7 days
- 4h: 4 hours
- 30m: 30 minutes
- 30s: 30 seconds
Args:
duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")
Returns:
timedelta object or None if invalid format
"""
if not duration_str:
return None
# Pattern: number followed by unit (d, h, m, s)
pattern = r"^(\d+)([dhms])$"
match = re.match(pattern, duration_str.lower())
if not match:
return None
value = int(match.group(1))
unit = match.group(2)
if unit == "d":
return timedelta(days=value)
elif unit == "h":
return timedelta(hours=value)
elif unit == "m":
return timedelta(minutes=value)
elif unit == "s":
return timedelta(seconds=value)
return None
def get_time_threshold(duration_str: str | None) -> datetime | None:
"""
Get datetime threshold from duration string.
Calculates the datetime that is duration_str ago from now.
Args:
duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")
Returns:
datetime object representing the threshold time, or None if no duration
"""
if not duration_str:
return None
duration = parse_time_duration(duration_str)
if duration is None:
return None
return datetime.now(UTC) - duration

View File

@@ -59,6 +59,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
triggered_from: str,
limit: int = 20,
last_id: str | None = None,
status: str | None = None,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs with filtering.
@@ -73,6 +74,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
limit: Maximum number of records to return (default: 20)
last_id: Cursor for pagination - ID of the last record from previous page
status: Optional filter by status (e.g., "running", "succeeded", "failed")
Returns:
InfiniteScrollPagination object containing:
@@ -107,6 +109,43 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
...
def get_workflow_runs_count(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
status: str | None = None,
time_range: str | None = None,
) -> dict[str, int]:
"""
Get workflow runs count statistics.
Retrieves total count and count by status for workflow runs
matching the specified filters.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
app_id: Application identifier
triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
status: Optional filter by specific status
time_range: Optional time range filter (e.g., "7d", "4h", "30m", "30s")
Filters records based on created_at field
Returns:
Dictionary containing:
- total: Total count of all workflow runs (or filtered by status)
- running: Count of workflow runs with status "running"
- succeeded: Count of workflow runs with status "succeeded"
- failed: Count of workflow runs with status "failed"
- stopped: Count of workflow runs with status "stopped"
- partial_succeeded: Count of workflow runs with status "partial-succeeded"
Note: If a status is provided, 'total' will be the count for that status,
and the specific status count will also be set to this value, with all
other status counts being 0.
"""
...
def get_expired_runs_batch(
self,
tenant_id: str,

View File

@@ -24,11 +24,12 @@ from collections.abc import Sequence
from datetime import datetime
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy import delete, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
@@ -63,6 +64,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from: str,
limit: int = 20,
last_id: str | None = None,
status: str | None = None,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs with filtering.
@@ -79,6 +81,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
WorkflowRun.triggered_from == triggered_from,
)
# Add optional status filter
if status:
base_stmt = base_stmt.where(WorkflowRun.status == status)
if last_id:
# Get the last workflow run for cursor-based pagination
last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
@@ -120,6 +126,73 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
)
return session.scalar(stmt)
def get_workflow_runs_count(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
status: str | None = None,
time_range: str | None = None,
) -> dict[str, int]:
"""
Get workflow runs count statistics grouped by status.
"""
_initial_status_counts = {
"running": 0,
"succeeded": 0,
"failed": 0,
"stopped": 0,
"partial-succeeded": 0,
}
with self._session_maker() as session:
# Build base where conditions
base_conditions = [
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.app_id == app_id,
WorkflowRun.triggered_from == triggered_from,
]
# Add time range filter if provided
if time_range:
time_threshold = get_time_threshold(time_range)
if time_threshold:
base_conditions.append(WorkflowRun.created_at >= time_threshold)
# If status filter is provided, return simple count
if status:
count_stmt = select(func.count(WorkflowRun.id)).where(*base_conditions, WorkflowRun.status == status)
total = session.scalar(count_stmt) or 0
result = {"total": total} | _initial_status_counts
# Set the count for the filtered status
if status in result:
result[status] = total
return result
# No status filter - get counts grouped by status
base_stmt = (
select(WorkflowRun.status, func.count(WorkflowRun.id).label("count"))
.where(*base_conditions)
.group_by(WorkflowRun.status)
)
# Execute query
results = session.execute(base_stmt).all()
# Build response dictionary
status_counts = _initial_status_counts.copy()
total = 0
for status_val, count in results:
total += count
if status_val in status_counts:
status_counts[status_val] = count
return {"total": total} | status_counts
def get_expired_runs_batch(
self,
tenant_id: str,

View File

@@ -26,13 +26,15 @@ class WorkflowRunService:
)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
def get_paginate_advanced_chat_workflow_runs(
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
) -> InfiniteScrollPagination:
"""
Get advanced chat app workflow run list
Only return triggered_from == advanced_chat
:param app_model: app model
:param args: request args
:param triggered_from: workflow run triggered from (default: DEBUGGING for preview runs)
"""
class WorkflowWithMessage:
@@ -45,7 +47,7 @@ class WorkflowRunService:
def __getattr__(self, item):
return getattr(self._workflow_run, item)
pagination = self.get_paginate_workflow_runs(app_model, args)
pagination = self.get_paginate_workflow_runs(app_model, args, triggered_from)
with_message_workflow_runs = []
for workflow_run in pagination.data:
@@ -60,23 +62,27 @@ class WorkflowRunService:
pagination.data = with_message_workflow_runs
return pagination
def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
def get_paginate_workflow_runs(
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
) -> InfiniteScrollPagination:
"""
Get debug workflow run list
Only return triggered_from == debugging
Get workflow run list
:param app_model: app model
:param args: request args
:param triggered_from: workflow run triggered from (default: DEBUGGING)
"""
limit = int(args.get("limit", 20))
last_id = args.get("last_id")
status = args.get("status")
return self._workflow_run_repo.get_paginated_workflow_runs(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
triggered_from=triggered_from,
limit=limit,
last_id=last_id,
status=status,
)
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
@@ -92,6 +98,30 @@ class WorkflowRunService:
run_id=run_id,
)
def get_workflow_runs_count(
self,
app_model: App,
status: str | None = None,
time_range: str | None = None,
triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
) -> dict[str, int]:
"""
Get workflow runs count statistics
:param app_model: app model
:param status: optional status filter
:param time_range: optional time range filter (e.g., "7d", "4h", "30m", "30s")
:param triggered_from: workflow run triggered from (default: DEBUGGING)
:return: dict with total and status counts
"""
return self._workflow_run_repo.get_workflow_runs_count(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=triggered_from,
status=status,
time_range=time_range,
)
def get_workflow_run_node_executions(
self,
app_model: App,

View File

@@ -0,0 +1,68 @@
"""Unit tests for custom input types."""
import pytest
from libs.custom_inputs import time_duration
class TestTimeDuration:
"""Test time_duration input validator."""
def test_valid_days(self):
"""Test valid days format."""
result = time_duration("7d")
assert result == "7d"
def test_valid_hours(self):
"""Test valid hours format."""
result = time_duration("4h")
assert result == "4h"
def test_valid_minutes(self):
"""Test valid minutes format."""
result = time_duration("30m")
assert result == "30m"
def test_valid_seconds(self):
"""Test valid seconds format."""
result = time_duration("30s")
assert result == "30s"
def test_uppercase_conversion(self):
"""Test uppercase units are converted to lowercase."""
result = time_duration("7D")
assert result == "7d"
result = time_duration("4H")
assert result == "4h"
def test_invalid_format_no_unit(self):
"""Test invalid format without unit."""
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("7")
def test_invalid_format_wrong_unit(self):
"""Test invalid format with wrong unit."""
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("7days")
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("7x")
def test_invalid_format_no_number(self):
"""Test invalid format without number."""
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("d")
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("abc")
def test_empty_string(self):
"""Test empty string."""
with pytest.raises(ValueError, match="Time duration cannot be empty"):
time_duration("")
def test_none(self):
"""Test None value."""
with pytest.raises(ValueError, match="Time duration cannot be empty"):
time_duration(None)

View File

@@ -0,0 +1,91 @@
"""Unit tests for time parser utility."""
from datetime import UTC, datetime, timedelta
from libs.time_parser import get_time_threshold, parse_time_duration
class TestParseTimeDuration:
"""Test parse_time_duration function."""
def test_parse_days(self):
"""Test parsing days."""
result = parse_time_duration("7d")
assert result == timedelta(days=7)
def test_parse_hours(self):
"""Test parsing hours."""
result = parse_time_duration("4h")
assert result == timedelta(hours=4)
def test_parse_minutes(self):
"""Test parsing minutes."""
result = parse_time_duration("30m")
assert result == timedelta(minutes=30)
def test_parse_seconds(self):
"""Test parsing seconds."""
result = parse_time_duration("30s")
assert result == timedelta(seconds=30)
def test_parse_uppercase(self):
"""Test parsing uppercase units."""
result = parse_time_duration("7D")
assert result == timedelta(days=7)
def test_parse_invalid_format(self):
"""Test parsing invalid format."""
result = parse_time_duration("7days")
assert result is None
result = parse_time_duration("abc")
assert result is None
result = parse_time_duration("7")
assert result is None
def test_parse_empty_string(self):
"""Test parsing empty string."""
result = parse_time_duration("")
assert result is None
def test_parse_none(self):
"""Test parsing None."""
result = parse_time_duration(None)
assert result is None
class TestGetTimeThreshold:
"""Test get_time_threshold function."""
def test_get_threshold_days(self):
"""Test getting threshold for days."""
before = datetime.now(UTC)
result = get_time_threshold("7d")
after = datetime.now(UTC)
assert result is not None
# Result should be approximately 7 days ago
expected = before - timedelta(days=7)
# Allow 1 second tolerance for test execution time
assert abs((result - expected).total_seconds()) < 1
def test_get_threshold_hours(self):
"""Test getting threshold for hours."""
before = datetime.now(UTC)
result = get_time_threshold("4h")
after = datetime.now(UTC)
assert result is not None
expected = before - timedelta(hours=4)
assert abs((result - expected).total_seconds()) < 1
def test_get_threshold_invalid(self):
"""Test getting threshold with invalid duration."""
result = get_time_threshold("invalid")
assert result is None
def test_get_threshold_none(self):
"""Test getting threshold with None."""
result = get_time_threshold(None)
assert result is None

View File

@@ -0,0 +1,251 @@
"""Unit tests for workflow run repository with status filter."""
import uuid
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import sessionmaker
from models import WorkflowRun, WorkflowRunTriggeredFrom
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
class TestDifyAPISQLAlchemyWorkflowRunRepository:
"""Test workflow run repository with status filtering."""
@pytest.fixture
def mock_session_maker(self):
"""Create a mock session maker."""
return MagicMock(spec=sessionmaker)
@pytest.fixture
def repository(self, mock_session_maker):
"""Create repository instance with mock session."""
return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker):
"""Test getting paginated workflow runs without status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)]
mock_session.scalars.return_value.all.return_value = mock_runs
# Act
result = repository.get_paginated_workflow_runs(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=None,
status=None,
)
# Assert
assert len(result.data) == 3
assert result.limit == 20
assert result.has_more is False
def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker):
"""Test getting paginated workflow runs with status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)]
mock_session.scalars.return_value.all.return_value = mock_runs
# Act
result = repository.get_paginated_workflow_runs(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=None,
status="succeeded",
)
# Assert
assert len(result.data) == 2
assert all(run.status == "succeeded" for run in result.data)
def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker):
"""Test getting workflow runs count without status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the GROUP BY query results
mock_results = [
("succeeded", 5),
("failed", 2),
("running", 1),
]
mock_session.execute.return_value.all.return_value = mock_results
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status=None,
)
# Assert
assert result["total"] == 8
assert result["succeeded"] == 5
assert result["failed"] == 2
assert result["running"] == 1
assert result["stopped"] == 0
assert result["partial-succeeded"] == 0
def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker):
"""Test getting workflow runs count with status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the count query for succeeded status
mock_session.scalar.return_value = 5
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="succeeded",
)
# Assert
assert result["total"] == 5
assert result["succeeded"] == 5
assert result["running"] == 0
assert result["failed"] == 0
assert result["stopped"] == 0
assert result["partial-succeeded"] == 0
def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker):
"""Test that invalid status is still counted in total but not in any specific status."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock count query returning 0 for invalid status
mock_session.scalar.return_value = 0
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="invalid_status",
)
# Assert
assert result["total"] == 0
assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"])
def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker):
"""Test getting workflow runs count with time range filter verifies SQL query construction."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the GROUP BY query results
mock_results = [
("succeeded", 3),
("running", 2),
]
mock_session.execute.return_value.all.return_value = mock_results
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status=None,
time_range="1d",
)
# Assert results
assert result["total"] == 5
assert result["succeeded"] == 3
assert result["running"] == 2
assert result["failed"] == 0
# Verify that execute was called (which means GROUP BY query was used)
assert mock_session.execute.called, "execute should have been called for GROUP BY query"
# Verify SQL query includes time filter by checking the statement
call_args = mock_session.execute.call_args
assert call_args is not None, "execute should have been called with a statement"
# The first argument should be the SQL statement
stmt = call_args[0][0]
# Convert to string to inspect the query
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
# Verify the query includes created_at filter
# The query should have a WHERE clause with created_at comparison
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
"Query should include created_at filter for time range"
)
def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker):
"""Test getting workflow runs count with both status and time range filters verifies SQL query."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the count query for running status within time range
mock_session.scalar.return_value = 2
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="running",
time_range="1d",
)
# Assert results
assert result["total"] == 2
assert result["running"] == 2
assert result["succeeded"] == 0
assert result["failed"] == 0
# Verify that scalar was called (which means COUNT query was used)
assert mock_session.scalar.called, "scalar should have been called for count query"
# Verify SQL query includes both status and time filter
call_args = mock_session.scalar.call_args
assert call_args is not None, "scalar should have been called with a statement"
# The first argument should be the SQL statement
stmt = call_args[0][0]
# Convert to string to inspect the query
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
# Verify the query includes both filters
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
"Query should include created_at filter for time range"
)
assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), (
"Query should include status filter"
)