mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
Enhanced GraphEngine Pause Handling (#28196)
This commit: 1. Convert `pause_reason` to `pause_reasons` in `GraphExecution` and relevant classes. Change the field from a scalar value to a list that can contain multiple `PauseReason` objects, ensuring all pause events are properly captured. 2. Introduce a new `WorkflowPauseReason` model to record reasons associated with a specific `WorkflowPause`. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@ layers =
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
runtime
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
|
||||
@@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=self._state_owner_user_id,
|
||||
state=state.dumps(),
|
||||
pause_reasons=event.reasons,
|
||||
)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||
from ..runtime.variable_pool import VariablePool
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_pause import WorkflowPauseEntity
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"VariablePool",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowPauseEntity",
|
||||
]
|
||||
|
||||
@@ -1,49 +1,26 @@
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, ClassVar, TypeAlias
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Discriminator, Tag
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class _PauseReasonType(StrEnum):
|
||||
class PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
SCHEDULED_PAUSE = auto()
|
||||
|
||||
|
||||
class _PauseReasonBase(BaseModel):
|
||||
TYPE: ClassVar[_PauseReasonType]
|
||||
class HumanInputRequired(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
form_id: str
|
||||
# The identifier of the human input node causing the pause.
|
||||
node_id: str
|
||||
|
||||
|
||||
class HumanInputRequired(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
|
||||
class SchedulingPause(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.SCHEDULED_PAUSE
|
||||
class SchedulingPause(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
|
||||
if isinstance(v, _PauseReasonBase):
|
||||
return v.TYPE
|
||||
elif isinstance(v, dict):
|
||||
reason_type_str = v.get("TYPE")
|
||||
if reason_type_str is None:
|
||||
return None
|
||||
try:
|
||||
reason_type = _PauseReasonType(reason_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return reason_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
PauseReason: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
|
||||
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
|
||||
),
|
||||
Discriminator(_get_pause_reason_discriminator),
|
||||
]
|
||||
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
|
||||
|
||||
@@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reason: PauseReason | None = Field(default=None)
|
||||
pause_reasons: list[PauseReason] = Field(default_factory=list)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
@@ -107,7 +107,7 @@ class GraphExecution:
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reason: PauseReason | None = None
|
||||
pause_reasons: list[PauseReason] = field(default_factory=list)
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
@@ -137,10 +137,8 @@ class GraphExecution:
|
||||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
if self.aborted:
|
||||
raise RuntimeError("Cannot pause execution that has been aborted")
|
||||
if self.paused:
|
||||
return
|
||||
self.paused = True
|
||||
self.pause_reason = reason
|
||||
self.pause_reasons.append(reason)
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
@@ -195,7 +193,7 @@ class GraphExecution:
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
paused=self.paused,
|
||||
pause_reason=self.pause_reason,
|
||||
pause_reasons=self.pause_reasons,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
@@ -221,7 +219,7 @@ class GraphExecution:
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.paused = state.paused
|
||||
self.pause_reason = state.pause_reason
|
||||
self.pause_reasons = state.pause_reasons
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
|
||||
@@ -110,7 +110,13 @@ class EventManager:
|
||||
"""
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
# NOTE: `_notify_layers` is intentionally called outside the critical section
|
||||
# to minimize lock contention and avoid blocking other readers or writers.
|
||||
#
|
||||
# The public `notify_layers` method also does not use a write lock,
|
||||
# so protecting `_notify_layers` with a lock here is unnecessary.
|
||||
self._notify_layers(event)
|
||||
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
|
||||
@@ -232,7 +232,7 @@ class GraphEngine:
|
||||
self._graph_execution.start()
|
||||
else:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reason = None
|
||||
self._graph_execution.pause_reasons = []
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
self._event_manager.notify_layers(start_event)
|
||||
@@ -246,11 +246,11 @@ class GraphEngine:
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.is_paused:
|
||||
pause_reason = self._graph_execution.pause_reason
|
||||
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
|
||||
pause_reasons = self._graph_execution.pause_reasons
|
||||
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
|
||||
# Ensure we have a valid PauseReason for the event
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reason=pause_reason,
|
||||
reasons=pause_reasons,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
|
||||
@@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
|
||||
class GraphRunPausedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is paused by user command."""
|
||||
|
||||
# reason: str | None = Field(default=None, description="reason for pause")
|
||||
reason: PauseReason = Field(..., description="reason for pause")
|
||||
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs available to the client while the run is paused.",
|
||||
|
||||
@@ -65,7 +65,8 @@ class HumanInputNode(Node):
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired())
|
||||
# TODO(QuantumGhost): yield a real form id.
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Protocol
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
|
||||
@@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):
|
||||
|
||||
|
||||
class GraphExecutionProtocol(Protocol):
|
||||
"""Structural interface for graph execution aggregate."""
|
||||
"""Structural interface for graph execution aggregate.
|
||||
|
||||
Defines the minimal set of attributes and methods required from a GraphExecution entity
|
||||
for runtime orchestration and state management.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
started: bool
|
||||
@@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
|
||||
aborted: bool
|
||||
error: Exception | None
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Add workflow_pauses_reasons table
|
||||
|
||||
Revision ID: 7bb281b7a422
|
||||
Revises: 09cfdda155d1
|
||||
Create Date: 2025-11-18 18:59:26.999572
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7bb281b7a422"
|
||||
down_revision = "09cfdda155d1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"workflow_pause_reasons",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
|
||||
sa.Column("pause_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("type_", sa.String(20), nullable=False),
|
||||
sa.Column("form_id", sa.String(length=36), nullable=False),
|
||||
sa.Column("node_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("message", sa.String(length=255), nullable=False),
|
||||
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("workflow_pause_reasons_pkey")),
|
||||
)
|
||||
with op.batch_alter_table("workflow_pause_reasons", schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f("workflow_pause_reasons_pause_id_idx"), ["pause_id"], unique=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("workflow_pause_reasons")
|
||||
@@ -29,6 +29,7 @@ from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
@@ -1728,3 +1729,68 @@ class WorkflowPause(DefaultFieldsMixin, Base):
|
||||
primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
|
||||
back_populates="pause",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "workflow_pause_reasons"
|
||||
|
||||
# `pause_id` represents the identifier of the pause,
|
||||
# correspond to the `id` field of `WorkflowPause`.
|
||||
pause_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
||||
|
||||
type_: Mapped[PauseReasonType] = mapped_column(EnumText(PauseReasonType), nullable=False)
|
||||
|
||||
# form_id is not empty if and if only type_ == PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
#
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
# message records the text description of this pause reason. For example,
|
||||
# "The workflow has been paused due to scheduling."
|
||||
#
|
||||
# Empty message means that this pause reason is not speified.
|
||||
message: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
# `node_id` is the identifier of node causing the pasue, correspond to
|
||||
# `Node.id`. Empty `node_id` means that this pause reason is not caused by any specific node
|
||||
# (E.G. time slicing pauses.)
|
||||
node_id: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
# Relationship to WorkflowPause
|
||||
pause: Mapped[WorkflowPause] = orm.relationship(
|
||||
foreign_keys=[pause_id],
|
||||
# require explicit preloading.
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
|
||||
if isinstance(pause_reason, HumanInputRequired):
|
||||
return cls(
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
|
||||
)
|
||||
elif isinstance(pause_reason, SchedulingPause):
|
||||
return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
|
||||
else:
|
||||
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
|
||||
|
||||
def to_entity(self) -> PauseReason:
|
||||
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
|
||||
elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
|
||||
return SchedulingPause(message=self.message)
|
||||
else:
|
||||
raise AssertionError(f"Unknown pause reason type: {self.type_}")
|
||||
|
||||
@@ -38,11 +38,12 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.types import (
|
||||
AverageInteractionStats,
|
||||
DailyRunsStats,
|
||||
@@ -257,6 +258,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
workflow_run_id: str,
|
||||
state_owner_user_id: str,
|
||||
state: str,
|
||||
pause_reasons: Sequence[PauseReason],
|
||||
) -> WorkflowPauseEntity:
|
||||
"""
|
||||
Create a new workflow pause state.
|
||||
|
||||
@@ -7,8 +7,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
class WorkflowPauseEntity(ABC):
|
||||
"""
|
||||
@@ -59,3 +62,15 @@ class WorkflowPauseEntity(ABC):
|
||||
the pause is not resumed yet.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||
"""
|
||||
Retrieve detailed reasons for this pause.
|
||||
|
||||
Returns a sequence of `PauseReason` objects describing the specific nodes and
|
||||
reasons for which the workflow execution was paused.
|
||||
This information is related to, but distinct from, the `PauseReason` type
|
||||
defined in `api/core/workflow/entities/pause_reason.py`.
|
||||
"""
|
||||
...
|
||||
@@ -31,7 +31,7 @@ from sqlalchemy import and_, delete, func, null, or_, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -41,8 +41,9 @@ from libs.time_parser import get_time_threshold
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowRun
|
||||
from models.workflow import WorkflowPauseReason, WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.types import (
|
||||
AverageInteractionStats,
|
||||
DailyRunsStats,
|
||||
@@ -318,6 +319,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
workflow_run_id: str,
|
||||
state_owner_user_id: str,
|
||||
state: str,
|
||||
pause_reasons: Sequence[PauseReason],
|
||||
) -> WorkflowPauseEntity:
|
||||
"""
|
||||
Create a new workflow pause state.
|
||||
@@ -371,6 +373,25 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
pause_model.workflow_run_id = workflow_run.id
|
||||
pause_model.state_object_key = state_obj_key
|
||||
pause_model.created_at = naive_utc_now()
|
||||
pause_reason_models = []
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
form_id=reason.form_id,
|
||||
)
|
||||
elif isinstance(reason, SchedulingPause):
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
message=reason.message,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"unkown reason type: {type(reason)}")
|
||||
|
||||
pause_reason_models.append(pause_reason_model)
|
||||
|
||||
# Update workflow run status
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
@@ -378,10 +399,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
# Save everything in a transaction
|
||||
session.add(pause_model)
|
||||
session.add(workflow_run)
|
||||
session.add_all(pause_reason_models)
|
||||
|
||||
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
|
||||
|
||||
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
|
||||
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
|
||||
pause_reason_models = session.scalars(reason_stmt).all()
|
||||
return pause_reason_models
|
||||
|
||||
def get_workflow_pause(
|
||||
self,
|
||||
@@ -413,8 +440,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
pause_model = workflow_run.pause
|
||||
if pause_model is None:
|
||||
return None
|
||||
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||
human_input_form: list[Any] = []
|
||||
# TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
|
||||
|
||||
return _PrivateWorkflowPauseEntity(
|
||||
pause_model=pause_model,
|
||||
reason_models=pause_reason_models,
|
||||
human_input_form=human_input_form,
|
||||
)
|
||||
|
||||
def resume_workflow_pause(
|
||||
self,
|
||||
@@ -466,6 +501,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
if pause_model.resumed_at is not None:
|
||||
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
|
||||
|
||||
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||
|
||||
# Mark as resumed
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
workflow_run.pause_id = None # type: ignore
|
||||
@@ -476,7 +513,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
|
||||
|
||||
def delete_workflow_pause(
|
||||
self,
|
||||
@@ -815,26 +852,13 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
self,
|
||||
*,
|
||||
pause_model: WorkflowPauseModel,
|
||||
reason_models: Sequence[WorkflowPauseReason],
|
||||
human_input_form: Sequence = (),
|
||||
) -> None:
|
||||
self._pause_model = pause_model
|
||||
self._reason_models = reason_models
|
||||
self._cached_state: bytes | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
|
||||
"""
|
||||
Create a _PrivateWorkflowPauseEntity from database models.
|
||||
|
||||
Args:
|
||||
workflow_pause_model: The WorkflowPause database model
|
||||
upload_file_model: The UploadFile database model
|
||||
|
||||
Returns:
|
||||
_PrivateWorkflowPauseEntity: The constructed entity
|
||||
|
||||
Raises:
|
||||
ValueError: If required model attributes are missing
|
||||
"""
|
||||
return cls(pause_model=workflow_pause_model)
|
||||
self._human_input_form = human_input_form
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
@@ -867,3 +891,6 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return self._pause_model.resumed_at
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||
return [reason.to_entity() for reason in self._reason_models]
|
||||
|
||||
@@ -15,7 +15,7 @@ from core.file import File
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.entities import VariablePool, WorkflowNodeExecution
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
@@ -24,6 +24,7 @@ from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
||||
@@ -319,7 +319,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
|
||||
# Create pause event
|
||||
event = GraphRunPausedEvent(
|
||||
reason=SchedulingPause(message="test pause"),
|
||||
reasons=[SchedulingPause(message="test pause")],
|
||||
outputs={"intermediate": "result"},
|
||||
)
|
||||
|
||||
@@ -381,7 +381,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act - Save pause state
|
||||
layer.on_event(event)
|
||||
@@ -390,6 +390,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
assert pause_entity.get_pause_reasons() == event.reasons
|
||||
|
||||
state_bytes = pause_entity.get_state()
|
||||
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
|
||||
@@ -414,7 +415,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
@@ -448,7 +449,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
@@ -514,7 +515,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
@@ -570,7 +571,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
# Don't initialize - graph_runtime_state should not be set
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act & Assert - Should raise AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
|
||||
@@ -334,12 +334,14 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Assert - Pause state created
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.id is not None
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
assert list(pause_entity.get_pause_reasons()) == []
|
||||
# Convert both to strings for comparison
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
@@ -366,6 +368,7 @@ class TestWorkflowPauseIntegration:
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
assert list(retrieved_entity.get_pause_reasons()) == []
|
||||
|
||||
# Act - Resume workflow
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
@@ -402,6 +405,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
assert pause_entity is not None
|
||||
@@ -432,6 +436,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
|
||||
@@ -449,6 +454,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
@@ -480,6 +486,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
@@ -503,6 +510,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
@@ -530,6 +538,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=nonexistent_id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
def test_resume_nonexistent_workflow_run(self):
|
||||
@@ -543,6 +552,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
@@ -570,6 +580,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Manually adjust timestamps for testing
|
||||
@@ -648,6 +659,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
pause_entities.append(pause_entity)
|
||||
|
||||
@@ -750,6 +762,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run1.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Try to access pause from tenant 2 using tenant 1's repository
|
||||
@@ -762,6 +775,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run2.id,
|
||||
state_owner_user_id=account2.id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Assert - Both pauses should exist and be separate
|
||||
@@ -782,6 +796,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Verify pause is properly scoped
|
||||
@@ -802,6 +817,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
@@ -828,9 +844,7 @@ class TestWorkflowPauseIntegration:
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, pause_reasons=[]
|
||||
)
|
||||
|
||||
# Get file info before deletion
|
||||
@@ -868,6 +882,7 @@ class TestWorkflowPauseIntegration:
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=large_state_json,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
@@ -902,9 +917,7 @@ class TestWorkflowPauseIntegration:
|
||||
|
||||
# Pause
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=state,
|
||||
workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=state, pause_reasons=[]
|
||||
)
|
||||
assert pause_entity is not None
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
|
||||
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
|
||||
return GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")], outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_started_event() -> GraphRunStartedEvent:
|
||||
@@ -255,15 +255,17 @@ class TestPauseStatePersistenceLayer:
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_called_once_with(session_factory)
|
||||
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||
workflow_run_id="run-123",
|
||||
state_owner_user_id="owner-123",
|
||||
state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
|
||||
)
|
||||
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
|
||||
assert mock_repo.create_workflow_pause.call_count == 1
|
||||
call_kwargs = mock_repo.create_workflow_pause.call_args.kwargs
|
||||
assert call_kwargs["workflow_run_id"] == "run-123"
|
||||
assert call_kwargs["state_owner_user_id"] == "owner-123"
|
||||
serialized_state = call_kwargs["state"]
|
||||
resumption_context = WorkflowResumptionContext.loads(serialized_state)
|
||||
assert resumption_context.serialized_graph_runtime_state == expected_state
|
||||
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
|
||||
pause_reasons = call_kwargs["pause_reasons"]
|
||||
|
||||
assert isinstance(pause_reasons, list)
|
||||
|
||||
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
|
||||
@@ -19,38 +19,18 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
# Create entity
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# Verify initialization
|
||||
assert entity._pause_model is mock_pause_model
|
||||
assert entity._cached_state is None
|
||||
|
||||
def test_from_models_classmethod(self):
|
||||
"""Test from_models class method."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
# Create entity using from_models
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(
|
||||
workflow_pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify entity creation
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model is mock_pause_model
|
||||
|
||||
def test_id_property(self):
|
||||
"""Test id property returns pause model ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.id == "pause-123"
|
||||
|
||||
@@ -59,9 +39,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.workflow_execution_id == "execution-456"
|
||||
|
||||
@@ -72,9 +50,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = resumed_at
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.resumed_at == resumed_at
|
||||
|
||||
@@ -83,9 +59,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.resumed_at is None
|
||||
|
||||
@@ -98,9 +72,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# First call should load from storage
|
||||
result = entity.get_state()
|
||||
@@ -118,9 +90,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# First call
|
||||
result1 = entity.get_state()
|
||||
@@ -139,9 +109,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# Pre-cache data
|
||||
entity._cached_state = state_data
|
||||
@@ -162,9 +130,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
result = entity.get_state()
|
||||
|
||||
|
||||
@@ -8,12 +8,13 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph.validation import GraphValidationError
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@@ -178,8 +178,7 @@ def test_pause_command():
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
||||
assert len(pause_events) == 1
|
||||
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
|
||||
assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")]
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.paused
|
||||
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
|
||||
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
|
||||
|
||||
@@ -6,10 +6,10 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
@@ -129,12 +129,14 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
state=state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
assert result.workflow_execution_id == workflow_run_id
|
||||
assert result.get_pause_reasons() == []
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
|
||||
@@ -156,6 +158,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
|
||||
@@ -174,6 +177,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -316,19 +320,10 @@ class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test _PrivateWorkflowPauseEntity class."""
|
||||
|
||||
def test_from_models(self, sample_workflow_pause: Mock):
|
||||
"""Test creating _PrivateWorkflowPauseEntity from models."""
|
||||
# Act
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
|
||||
# Assert
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model == sample_workflow_pause
|
||||
|
||||
def test_properties(self, sample_workflow_pause: Mock):
|
||||
"""Test entity properties."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
|
||||
# Act & Assert
|
||||
assert entity.id == sample_workflow_pause.id
|
||||
@@ -338,7 +333,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
|
||||
def test_get_state(self, sample_workflow_pause: Mock):
|
||||
"""Test getting state from storage."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
@@ -354,7 +349,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock):
|
||||
"""Test state caching in get_state method."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowPause
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
from services.workflow_run_service import (
|
||||
@@ -63,7 +64,7 @@ class TestDataFactory:
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowPauseModel object."""
|
||||
mock_pause = MagicMock()
|
||||
mock_pause = MagicMock(spec=WorkflowPause)
|
||||
mock_pause.id = id
|
||||
mock_pause.tenant_id = tenant_id
|
||||
mock_pause.app_id = app_id
|
||||
@@ -77,38 +78,15 @@ class TestDataFactory:
|
||||
|
||||
return mock_pause
|
||||
|
||||
@staticmethod
|
||||
def create_upload_file_mock(
|
||||
id: str = "file-456",
|
||||
key: str = "upload_files/test/state.json",
|
||||
name: str = "state.json",
|
||||
tenant_id: str = "tenant-456",
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock UploadFile object."""
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = id
|
||||
mock_file.key = key
|
||||
mock_file.name = name
|
||||
mock_file.tenant_id = tenant_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_file, key, value)
|
||||
|
||||
return mock_file
|
||||
|
||||
@staticmethod
|
||||
def create_pause_entity_mock(
|
||||
pause_model: MagicMock | None = None,
|
||||
upload_file: MagicMock | None = None,
|
||||
) -> _PrivateWorkflowPauseEntity:
|
||||
"""Create a mock _PrivateWorkflowPauseEntity object."""
|
||||
if pause_model is None:
|
||||
pause_model = TestDataFactory.create_workflow_pause_mock()
|
||||
if upload_file is None:
|
||||
upload_file = TestDataFactory.create_upload_file_mock()
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
|
||||
class TestWorkflowRunService:
|
||||
|
||||
Reference in New Issue
Block a user