feat: Human Input Node (#32060)

Adds the frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@@ -0,0 +1,3 @@
from .workflow_execute_task import AppExecutionParams, resume_app_execution, workflow_based_app_execution_task
__all__ = ["AppExecutionParams", "resume_app_execution", "workflow_based_app_execution_task"]

View File

@@ -0,0 +1,491 @@
import contextlib
import logging
import uuid
from collections.abc import Generator, Mapping
from enum import StrEnum
from typing import Annotated, Any, TypeAlias, Union
from celery import shared_task
from flask import current_app, json
from pydantic import BaseModel, Discriminator, Field, Tag
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_database import db
from libs.flask_utils import set_login_user
from models.account import Account
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import App, AppMode, Conversation, EndUser, Message
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
WORKFLOW_BASED_APP_EXECUTION_QUEUE = "workflow_based_app_execution"
class _UserType(StrEnum):
    """Discriminator values distinguishing serialized user payloads."""

    ACCOUNT = "account"
    END_USER = "end_user"
class _Account(BaseModel):
    """Serializable reference to a console ``Account`` user."""

    # Discriminator tag; a regular field (not ClassVar) so it is included in
    # the serialized task payload and available during deserialization.
    TYPE: _UserType = _UserType.ACCOUNT
    user_id: str


class _EndUser(BaseModel):
    """Serializable reference to an ``EndUser`` (web-app / service-API caller)."""

    TYPE: _UserType = _UserType.END_USER
    end_user_id: str
def _get_user_type_descriminator(value: Any):
    """Resolve the discriminator tag for the ``User`` union.

    Accepts either an already-constructed model (``_Account`` / ``_EndUser``)
    or a raw dict under validation. Returns ``None`` when no valid tag can be
    determined so pydantic reports a discriminator error.
    """
    if isinstance(value, (_Account, _EndUser)):
        return value.TYPE
    if not isinstance(value, dict):
        # Unrecognized input shape: no discriminator available.
        return None
    raw_tag = value.get("TYPE")
    if raw_tag is None:
        return None
    try:
        return _UserType(raw_tag)
    except ValueError:
        # Tag present but not a valid _UserType member.
        return None
# Tagged union used in AppExecutionParams. The callable discriminator lets
# pydantic pick the right model from either a model instance or a raw dict.
User: TypeAlias = Annotated[
    (Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]),
    Discriminator(_get_user_type_descriminator),
]
class AppExecutionParams(BaseModel):
    """Serializable parameters for running a workflow-based app in Celery.

    Instances are JSON-encoded and handed to
    ``workflow_based_app_execution_task``; use :meth:`new` to build one from
    ORM objects.
    """

    app_id: str
    workflow_id: str
    tenant_id: str
    app_mode: AppMode = AppMode.ADVANCED_CHAT
    user: User
    args: Mapping[str, Any]
    invoke_from: InvokeFrom
    streaming: bool = True
    call_depth: int = 0
    root_node_id: str | None = None
    # Pre-generated so callers can subscribe to the run before the task starts.
    workflow_run_id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    @classmethod
    def new(
        cls,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        root_node_id: str | None = None,
        workflow_run_id: str | None = None,
    ):
        """Build params from ORM objects, collapsing ``user`` to a tagged payload.

        Raises:
            AssertionError: if ``user`` is neither ``Account`` nor ``EndUser``.
        """
        user_params: _Account | _EndUser
        if isinstance(user, Account):
            user_params = _Account(user_id=user.id)
        elif isinstance(user, EndUser):
            user_params = _EndUser(end_user_id=user.id)
        else:
            raise AssertionError("this statement should be unreachable.")
        return cls(
            app_id=app_model.id,
            workflow_id=workflow.id,
            tenant_id=app_model.tenant_id,
            app_mode=AppMode.value_of(app_model.mode),
            user=user_params,
            args=args,
            invoke_from=invoke_from,
            streaming=streaming,
            call_depth=call_depth,
            root_node_id=root_node_id,
            workflow_run_id=workflow_run_id or str(uuid.uuid4()),
        )
class _AppRunner:
    """Executes a single app run described by ``AppExecutionParams``.

    Loads the workflow/app/user from the database, pushes a Flask app context
    with the user logged in, and dispatches to the generator matching the
    app mode.
    """

    def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams):
        # Accept a raw Engine (e.g. db.engine) for convenience and wrap it.
        if isinstance(session_factory, Engine):
            session_factory = sessionmaker(bind=session_factory)
        self._session_factory = session_factory
        self._exec_params = exec_params

    @contextlib.contextmanager
    def _session(self):
        # expire_on_commit=False keeps ORM objects usable after the
        # transaction commits; session.begin() commits on clean exit.
        with self._session_factory(expire_on_commit=False) as session, session.begin():
            yield session

    @contextlib.contextmanager
    def _setup_flask_context(self, user: Account | EndUser):
        """Push a fresh app context and mark ``user`` as the logged-in user."""
        flask_app = current_app._get_current_object()  # type: ignore
        with flask_app.app_context():
            set_login_user(user)
            yield

    def run(self):
        """Run the app; returns the blocking result, or None when streaming.

        Streaming responses are drained here and re-published to the run's
        response topic instead of being returned to the caller.
        """
        exec_params = self._exec_params
        with self._session() as session:
            workflow = session.get(Workflow, exec_params.workflow_id)
            if workflow is None:
                logger.warning("Workflow %s not found for execution", exec_params.workflow_id)
                return None
            app = session.get(App, workflow.app_id)
            if app is None:
                logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id)
                return None
        pause_config = PauseStateLayerConfig(
            session_factory=self._session_factory,
            state_owner_user_id=workflow.created_by,
        )
        user = self._resolve_user()
        with self._setup_flask_context(user):
            response = self._run_app(
                app=app,
                workflow=workflow,
                user=user,
                pause_state_config=pause_config,
            )
            if not exec_params.streaming:
                return response
            assert isinstance(response, Generator)
            _publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode)

    def _run_app(
        self,
        *,
        app: App,
        workflow: Workflow,
        user: Account | EndUser,
        pause_state_config: PauseStateLayerConfig,
    ):
        """Dispatch to the generator for ``app_mode``; returns None if unsupported."""
        exec_params = self._exec_params
        if exec_params.app_mode == AppMode.ADVANCED_CHAT:
            return AdvancedChatAppGenerator().generate(
                app_model=app,
                workflow=workflow,
                user=user,
                args=exec_params.args,
                invoke_from=exec_params.invoke_from,
                streaming=exec_params.streaming,
                workflow_run_id=exec_params.workflow_run_id,
                pause_state_config=pause_state_config,
            )
        if exec_params.app_mode == AppMode.WORKFLOW:
            return WorkflowAppGenerator().generate(
                app_model=app,
                workflow=workflow,
                user=user,
                args=exec_params.args,
                invoke_from=exec_params.invoke_from,
                streaming=exec_params.streaming,
                call_depth=exec_params.call_depth,
                root_node_id=exec_params.root_node_id,
                workflow_run_id=exec_params.workflow_run_id,
                pause_state_config=pause_state_config,
            )
        logger.error("Unsupported app mode for execution: %s", exec_params.app_mode)
        return None

    def _resolve_user(self) -> Account | EndUser:
        """Load the Account/EndUser referenced by the exec params.

        NOTE(review): ``session.get`` may return None for a stale id even
        though the annotation claims a non-optional result — confirm callers
        can rely on the user existing.
        """
        user_params = self._exec_params.user
        if isinstance(user_params, _EndUser):
            with self._session() as session:
                return session.get(EndUser, user_params.end_user_id)
        elif not isinstance(user_params, _Account):
            raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}")
        with self._session() as session:
            user: Account = session.get(Account, user_params.user_id)
            # Bind the tenant so downstream permission checks use the right workspace.
            user.set_tenant_id(self._exec_params.tenant_id)
            return user
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
    """Look up the user that created ``workflow_run``; None when not found."""
    creator_id = workflow_run.created_by
    if CreatorUserRole(workflow_run.created_by_role) != CreatorUserRole.ACCOUNT:
        # Non-account creators are end users (web app / service API callers).
        return session.get(EndUser, creator_id)
    account = session.get(Account, creator_id)
    if account:
        # Scope the account to the run's workspace for downstream checks.
        account.set_tenant_id(workflow_run.tenant_id)
    return account
def _publish_streaming_response(
    response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode
) -> None:
    """Drain a streaming response and publish each event to the run's topic.

    Events that cannot be JSON-encoded are logged and dropped so that a
    single bad event does not abort the rest of the stream.
    """
    topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id)
    for event in response_stream:
        try:
            encoded = json.dumps(event).encode()
        except TypeError:
            logger.exception("error while encoding event")
        else:
            topic.publish(encoded)
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE)
def workflow_based_app_execution_task(
    payload: str,
) -> Generator[Mapping[str, Any] | str, None, None] | Mapping[str, Any] | None:
    """Celery entry point: run a workflow-based app from a JSON payload.

    ``payload`` is a JSON-serialized :class:`AppExecutionParams`.
    """
    exec_params = AppExecutionParams.model_validate_json(payload)
    logger.info("workflow_based_app_execution_task run with params: %s", exec_params)
    runner = _AppRunner(db.engine, exec_params=exec_params)
    return runner.run()
def _resume_app_execution(payload: dict[str, Any]) -> None:
    """Resume a paused workflow or chatflow run from its persisted pause state.

    ``payload`` must contain ``workflow_run_id``. Missing rows or an
    unreadable pause state are logged and skipped (best-effort resumption).
    """
    workflow_run_id = payload["workflow_run_id"]
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory)
    pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
    if pause_entity is None:
        logger.warning("No pause entity found for workflow run %s", workflow_run_id)
        return
    try:
        resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
    except Exception:
        logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
        return
    generate_entity = resumption_context.get_generate_entity()
    # Rebuild the in-memory graph state exactly as it was when the run paused.
    graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
    conversation = None
    message = None
    with Session(db.engine, expire_on_commit=False) as session:
        workflow_run = session.get(WorkflowRun, workflow_run_id)
        if workflow_run is None:
            logger.warning("Workflow run %s not found during resume", workflow_run_id)
            return
        workflow = session.get(Workflow, workflow_run.workflow_id)
        if workflow is None:
            logger.warning("Workflow %s not found during resume", workflow_run.workflow_id)
            return
        app_model = session.get(App, workflow_run.app_id)
        if app_model is None:
            logger.warning("App %s not found during resume", workflow_run.app_id)
            return
        user = _resolve_user_for_run(session, workflow_run)
        if user is None:
            logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
            return
        if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
            # Chatflows additionally need the conversation and its latest message.
            if generate_entity.conversation_id is None:
                logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
                return
            conversation = session.get(Conversation, generate_entity.conversation_id)
            if conversation is None:
                logger.warning(
                    "Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
                )
                return
            message = session.scalar(
                select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
            )
            if message is None:
                logger.warning("Message not found for workflow run %s", workflow_run_id)
                return
    if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)):
        logger.error(
            "Unsupported resumption entity for workflow run %s (found %s)",
            workflow_run_id,
            type(generate_entity),
        )
        return
    # Mark the pause as resumed before re-entering the generator.
    workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
    pause_config = PauseStateLayerConfig(
        session_factory=session_factory,
        state_owner_user_id=workflow.created_by,
    )
    if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
        assert conversation is not None
        assert message is not None
        _resume_advanced_chat(
            app_model=app_model,
            workflow=workflow,
            user=user,
            conversation=conversation,
            message=message,
            generate_entity=generate_entity,
            graph_runtime_state=graph_runtime_state,
            session_factory=session_factory,
            pause_state_config=pause_config,
            workflow_run_id=workflow_run_id,
            workflow_run=workflow_run,
        )
    elif isinstance(generate_entity, WorkflowAppGenerateEntity):
        _resume_workflow(
            app_model=app_model,
            workflow=workflow,
            user=user,
            generate_entity=generate_entity,
            graph_runtime_state=graph_runtime_state,
            session_factory=session_factory,
            pause_state_config=pause_config,
            workflow_run_id=workflow_run_id,
            workflow_run=workflow_run,
            workflow_run_repo=workflow_run_repo,
            pause_entity=pause_entity,
        )
def _resume_advanced_chat(
    *,
    app_model: App,
    workflow: Workflow,
    user: Account | EndUser,
    conversation: Conversation,
    message: Message,
    generate_entity: AdvancedChatAppGenerateEntity,
    graph_runtime_state: GraphRuntimeState,
    session_factory: sessionmaker,
    pause_state_config: PauseStateLayerConfig,
    workflow_run_id: str,
    workflow_run: WorkflowRun,
) -> None:
    """Resume a paused advanced-chat (chatflow) run and republish its stream.

    Raises:
        Exception: any failure from the generator's ``resume`` is logged and
            re-raised so Celery can record the task failure.
    """
    try:
        triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
    except ValueError:
        # Unknown/legacy value stored on the run; fall back to a plain app run.
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN
    workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=triggered_from,
    )
    workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    generator = AdvancedChatAppGenerator()
    try:
        response = generator.resume(
            app_model=app_model,
            workflow=workflow,
            user=user,
            conversation=conversation,
            message=message,
            application_generate_entity=generate_entity,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            graph_runtime_state=graph_runtime_state,
            pause_state_config=pause_state_config,
        )
    except Exception:
        logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
        raise
    if generate_entity.stream:
        # Streaming runs: drain the generator and push events to subscribers.
        assert isinstance(response, Generator)
        _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
def _resume_workflow(
    *,
    app_model: App,
    workflow: Workflow,
    user: Account | EndUser,
    generate_entity: WorkflowAppGenerateEntity,
    graph_runtime_state: GraphRuntimeState,
    session_factory: sessionmaker,
    pause_state_config: PauseStateLayerConfig,
    workflow_run_id: str,
    workflow_run: WorkflowRun,
    # Untyped pass-throughs; concrete types come from repositories.factory.
    workflow_run_repo,
    pause_entity,
) -> None:
    """Resume a paused workflow-mode run, republish its stream, and clean up.

    The pause record is deleted only after a successful resume; a failure
    leaves it in place so the run can be retried.
    """
    try:
        triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
    except ValueError:
        # Unknown/legacy value stored on the run; fall back to a plain app run.
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN
    workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=triggered_from,
    )
    workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    generator = WorkflowAppGenerator()
    try:
        response = generator.resume(
            app_model=app_model,
            workflow=workflow,
            user=user,
            application_generate_entity=generate_entity,
            graph_runtime_state=graph_runtime_state,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            pause_state_config=pause_state_config,
        )
    except Exception:
        logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
        raise
    if generate_entity.stream:
        assert isinstance(response, Generator)
        _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
    # Only reached on success: drop the pause record so the run isn't resumed twice.
    workflow_run_repo.delete_workflow_pause(pause_entity)
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")
def resume_app_execution(payload: dict[str, Any]) -> None:
    """Celery entry point delegating to :func:`_resume_app_execution`."""
    _resume_app_execution(payload)

View File

@@ -5,32 +5,42 @@ These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling.
"""
import logging
from datetime import UTC, datetime
from typing import Any
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer
from core.app.layers.trigger_post_layer import TriggerPostLayer
from core.db.session_factory import session_factory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError
from services.workflow.entities import (
TriggerData,
WorkflowResumeTaskData,
WorkflowTaskData,
)
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy
logger = logging.getLogger(__name__)
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
@@ -141,6 +151,11 @@ def _execute_workflow_common(
if trigger_data.workflow_id:
args["workflow_id"] = str(trigger_data.workflow_id)
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
# Execute the workflow with the trigger type
generator.generate(
app_model=app_model,
@@ -156,6 +171,7 @@ def _execute_workflow_common(
# TODO: Re-enable TimeSliceLayer after the HITL release.
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
],
pause_state_config=pause_config,
)
except Exception as e:
@@ -173,21 +189,153 @@ def _execute_workflow_common(
session.commit()
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
@shared_task(name="resume_workflow_execution")
def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
    """Resume a paused workflow run via Celery.

    Loads the persisted pause state, rebuilds the generate entity and graph
    runtime state, re-attaches trigger scheduling layers where applicable,
    then hands execution back to ``WorkflowAppGenerator.resume``. The pause
    record is deleted only after a successful resume.

    Raises:
        WorkflowNotFoundError: the run's workflow row no longer exists.
        _AppNotFoundError: the run's app row no longer exists.
    """
    task_data = WorkflowResumeTaskData.model_validate(task_data_dict)
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
    pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
    if pause_entity is None:
        logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
        return
    workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(pause_entity.workflow_execution_id)
    if workflow_run is None:
        logger.warning("Workflow run not found for pause entity: pause_entity_id=%s", pause_entity.id)
        return
    try:
        resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
    except Exception:
        logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
        raise  # bare raise preserves the original exception and traceback
    generate_entity = resumption_context.get_generate_entity()
    if not isinstance(generate_entity, WorkflowAppGenerateEntity):
        logger.error(
            "Unsupported resumption entity for workflow run %s: %s",
            task_data.workflow_run_id,
            type(generate_entity),
        )
        return
    graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
    with session_factory() as session:
        workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
        if workflow is None:
            # f-strings here: exception constructors do not apply %-style
            # formatting, so logging-style args would never interpolate.
            raise WorkflowNotFoundError(
                f"Workflow not found: workflow_run_id={workflow_run.id}, workflow_id={workflow_run.workflow_id}"
            )
        user = _get_user(session, workflow_run)
        app_model = session.scalar(select(App).where(App.id == workflow_run.app_id))
        if app_model is None:
            raise _AppNotFoundError(f"App not found: app_id={workflow_run.app_id}, workflow_run_id={workflow_run.id}")
    workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=generate_entity.app_config.app_id,
        triggered_from=WorkflowRunTriggeredFrom(workflow_run.triggered_from),
    )
    workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=generate_entity.app_config.app_id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    pause_config = PauseStateLayerConfig(
        session_factory=session_factory,
        state_owner_user_id=workflow.created_by,
    )
    generator = WorkflowAppGenerator()
    start_time = datetime.now(UTC)
    graph_engine_layers = []
    # Trigger-initiated runs carry scheduling/post-processing layers that must
    # be re-attached so the resumed execution behaves like the original run.
    trigger_log = _query_trigger_log_info(session_factory, task_data.workflow_run_id)
    if trigger_log:
        cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
            queue=AsyncWorkflowQueue(trigger_log.queue_name),
            schedule_strategy=AsyncWorkflowSystemStrategy,
            granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
        )
        cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)
        graph_engine_layers.extend(
            [
                TimeSliceLayer(cfs_plan_scheduler),
                TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
            ]
        )
    workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
    generator.resume(
        app_model=app_model,
        workflow=workflow,
        user=user,
        application_generate_entity=generate_entity,
        graph_runtime_state=graph_runtime_state,
        workflow_execution_repository=workflow_execution_repository,
        workflow_node_execution_repository=workflow_node_execution_repository,
        graph_engine_layers=graph_engine_layers,
        pause_state_config=pause_config,
    )
    # Only reached on success: drop the pause record so the run isn't resumed twice.
    workflow_run_repo.delete_workflow_pause(pause_entity)
def _get_user(session: Session, workflow_run: WorkflowRun | WorkflowTriggerLog) -> Account | EndUser:
    """Resolve the creating user (Account or EndUser) for a run or trigger log.

    Accounts get ``current_tenant`` bound to the owning workspace.

    Raises:
        _TenantNotFoundError: when the owning tenant no longer exists.
        _UserNotFoundError: when the creating user no longer exists.
    """
    tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id))
    if not tenant:
        # f-strings here: exception constructors do not %-format their
        # arguments, so logging-style "(msg, arg, ...)" never interpolated.
        raise _TenantNotFoundError(
            f"Tenant not found for WorkflowRun: tenant_id={workflow_run.tenant_id}, "
            f"workflow_run_id={workflow_run.id}"
        )
    if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
        user = session.scalar(select(Account).where(Account.id == workflow_run.created_by))
        if user:
            user.current_tenant = tenant
    else:  # CreatorUserRole.END_USER
        user = session.scalar(select(EndUser).where(EndUser.id == workflow_run.created_by))
    if not user:
        raise _UserNotFoundError(
            f"User not found: user_id={workflow_run.created_by}, "
            f"created_by_role={workflow_run.created_by_role}, workflow_run_id={workflow_run.id}"
        )
    return user
def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run_id: str) -> WorkflowTriggerLog | None:
    """Fetch the trigger log for a workflow run, or None when absent."""
    with session_factory() as session, session.begin():
        trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
        trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
        if not trigger_log:
            logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id)
            return None
        return trigger_log
class _TenantNotFoundError(Exception):
    """Raised when a workflow run references a tenant that no longer exists."""

    pass


class _UserNotFoundError(Exception):
    """Raised when a workflow run's creating user cannot be found."""

    pass


class _AppNotFoundError(Exception):
    """Raised when a workflow run references an app that no longer exists."""

    pass

View File

@@ -0,0 +1,113 @@
import logging
from datetime import timedelta
from celery import shared_task
from sqlalchemy import or_, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.human_input_repository import HumanInputFormSubmissionRepository
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from models.human_input import HumanInputForm
from models.workflow import WorkflowPause, WorkflowRun
from services.human_input_service import HumanInputService
logger = logging.getLogger(__name__)
def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int, *, now) -> bool:
    """Whether the form has exceeded the tenant-wide (global) timeout.

    Always False when the global timeout is disabled (<= 0) or the form is
    not attached to a workflow run.
    """
    if global_timeout_seconds <= 0 or form_model.workflow_run_id is None:
        return False
    deadline = ensure_naive_utc(form_model.created_at) + timedelta(seconds=global_timeout_seconds)
    return now >= deadline
def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None:
    """Stop a workflow run whose human-input form hit the global timeout.

    Marks the run STOPPED with an explanatory error, deletes the persisted
    pause-state object from storage (best-effort), and stamps the pause row
    as resumed so it is not picked up again.

    NOTE(review): ``form_id`` is currently unused here — confirm whether it
    was intended for logging/auditing.
    """
    now = naive_utc_now()
    with session_factory() as session, session.begin():
        workflow_run = session.get(WorkflowRun, workflow_run_id)
        if workflow_run is not None:
            workflow_run.status = WorkflowExecutionStatus.STOPPED
            workflow_run.error = f"Human input global timeout at node {node_id}"
            workflow_run.finished_at = now
            session.add(workflow_run)
        pause_model = session.scalar(select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id))
        if pause_model is not None:
            try:
                storage.delete(pause_model.state_object_key)
            except Exception:
                # Best-effort cleanup: a dangling state object is preferable
                # to failing the whole timeout handling.
                logger.exception(
                    "Failed to delete pause state object for workflow_run_id=%s, pause_id=%s",
                    workflow_run_id,
                    pause_model.id,
                )
            pause_model.resumed_at = now
            session.add(pause_model)
@shared_task(name="human_input_form_timeout.check_and_resume", queue="schedule_executor")
def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
    """Scan for expired human input forms and resume or end workflows.

    Handles up to ``limit`` WAITING forms whose node-level expiration passed
    or whose age exceeds the global timeout. Delivery-test forms are merely
    marked TIMEOUT; real forms are marked and then either the run is stopped
    (global timeout) or queued for resumption (node timeout). Failures are
    logged per-form so one bad form does not block the rest of the batch.
    """
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    form_repo = HumanInputFormSubmissionRepository(session_factory)
    service = HumanInputService(session_factory, form_repository=form_repo)
    now = naive_utc_now()
    global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
    with session_factory() as session:
        # Forms created before this cutoff have exceeded the global timeout.
        global_deadline = now - timedelta(seconds=global_timeout_seconds) if global_timeout_seconds > 0 else None
        timeout_filter = HumanInputForm.expiration_time <= now
        if global_deadline is not None:
            timeout_filter = or_(timeout_filter, HumanInputForm.created_at <= global_deadline)
        stmt = (
            select(HumanInputForm)
            .where(
                HumanInputForm.status == HumanInputFormStatus.WAITING,
                timeout_filter,
            )
            .order_by(HumanInputForm.id.asc())
            .limit(limit)
        )
        expired_forms = session.scalars(stmt).all()
        for form_model in expired_forms:
            try:
                if form_model.form_kind == HumanInputFormKind.DELIVERY_TEST:
                    # Delivery tests have no workflow run to resume or stop.
                    form_repo.mark_timeout(
                        form_id=form_model.id,
                        timeout_status=HumanInputFormStatus.TIMEOUT,
                        reason="delivery_test_timeout",
                    )
                    continue
                is_global = _is_global_timeout(form_model, global_timeout_seconds, now=now)
                record = form_repo.mark_timeout(
                    form_id=form_model.id,
                    timeout_status=HumanInputFormStatus.EXPIRED if is_global else HumanInputFormStatus.TIMEOUT,
                    reason="global_timeout" if is_global else "node_timeout",
                )
                assert record.workflow_run_id is not None, "workflow_run_id should not be None for non-test form"
                if is_global:
                    _handle_global_timeout(
                        form_id=record.form_id,
                        workflow_run_id=record.workflow_run_id,
                        node_id=record.node_id,
                        session_factory=session_factory,
                    )
                else:
                    # Node-level timeout: resume the run with default/empty input.
                    service.enqueue_resume(record.workflow_run_id)
            except Exception:
                logger.exception(
                    "Failed to handle timeout for form_id=%s workflow_run_id=%s",
                    form_model.id,
                    form_model.workflow_run_id,
                )

View File

@@ -0,0 +1,190 @@
import json
import logging
import time
from dataclasses import dataclass
from typing import Any
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod
from core.workflow.runtime import GraphRuntimeState, VariablePool
from extensions.ext_database import db
from extensions.ext_mail import mail
from models.human_input import (
DeliveryMethodType,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class _EmailRecipient:
    """A single addressee plus the access token embedded in their form link."""

    email: str
    token: str


@dataclass(frozen=True)
class _EmailDeliveryJob:
    """Everything needed to send one delivery's emails for a form."""

    form_id: str
    subject: str
    body: str
    form_content: str
    recipients: list[_EmailRecipient]
def _build_form_link(token: str) -> str:
    """Absolute web-app URL at which a recipient can open the form."""
    root = dify_config.APP_WEB_URL.rstrip("/")
    return f"{root}/form/{token}"
def _parse_recipient_payload(payload: str) -> tuple[str | None, RecipientType | None]:
    """Extract ``(email, recipient type)`` from a recipient's JSON payload.

    Returns ``(None, None)`` when the payload is not valid JSON or is not a
    JSON object, so a malformed recipient is skipped rather than fatal.
    """
    try:
        payload_dict: dict[str, Any] = json.loads(payload)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError; TypeError covers non-str input.
        # Narrowed from a bare `except Exception` so real bugs still surface.
        logger.exception("Failed to parse recipient payload")
        return None, None
    if not isinstance(payload_dict, dict):
        # Valid JSON but not an object (e.g. a list) — previously this crashed
        # with AttributeError on `.get`; treat it as malformed instead.
        logger.warning("Recipient payload is not a JSON object")
        return None, None
    return payload_dict.get("email"), payload_dict.get("TYPE")
def _load_email_jobs(session: Session, form: HumanInputForm) -> list[_EmailDeliveryJob]:
    """Collect the email delivery jobs (subject/body + recipients) for a form.

    Deliveries that end up with no usable recipient (missing email or access
    token, or a non-email recipient type) are skipped entirely.
    """
    deliveries = session.scalars(
        select(HumanInputDelivery).where(
            HumanInputDelivery.form_id == form.id,
            HumanInputDelivery.delivery_method_type == DeliveryMethodType.EMAIL,
        )
    ).all()
    jobs: list[_EmailDeliveryJob] = []
    for delivery in deliveries:
        delivery_config = EmailDeliveryMethod.model_validate_json(delivery.channel_payload)
        recipients = session.scalars(
            select(HumanInputFormRecipient).where(HumanInputFormRecipient.delivery_id == delivery.id)
        ).all()
        recipient_entities: list[_EmailRecipient] = []
        for recipient in recipients:
            email, recipient_type = _parse_recipient_payload(recipient.recipient_payload)
            if recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}:
                continue
            if not email:
                continue
            token = recipient.access_token
            if not token:
                # Without a token we cannot build a personalized form link.
                continue
            recipient_entities.append(_EmailRecipient(email=email, token=token))
        if not recipient_entities:
            continue
        jobs.append(
            _EmailDeliveryJob(
                form_id=form.id,
                subject=delivery_config.config.subject,
                body=delivery_config.config.body,
                form_content=form.rendered_content,
                recipients=recipient_entities,
            )
        )
    return jobs
def _render_body(
    body_template: str,
    form_link: str,
    *,
    variable_pool: VariablePool | None,
) -> str:
    """Render the email body template with the form URL and workflow variables."""
    return EmailDeliveryConfig.render_body_template(
        body=body_template,
        url=form_link,
        variable_pool=variable_pool,
    )
def _load_variable_pool(workflow_run_id: str | None) -> VariablePool | None:
    """Recover the paused run's variable pool for body-template rendering.

    Returns None when the run id is absent, no pause state exists, or the
    persisted state cannot be deserialized.
    """
    if not workflow_run_id:
        return None
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
    pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
    if pause_entity is None:
        logger.info("No pause state found for workflow run %s", workflow_run_id)
        return None
    try:
        resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
    except Exception:
        logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
        return None
    graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
    return graph_runtime_state.variable_pool
def _open_session(session_factory: sessionmaker | Session | None):
    """Return a context manager yielding a SQLAlchemy session.

    A caller-supplied ``Session`` is yielded as-is and NOT closed on exit —
    the caller owns its lifecycle. Otherwise a new session is created from
    the given factory, or from ``db.engine`` when no factory is provided.
    """
    import contextlib  # local import keeps the module's import surface unchanged

    if session_factory is None:
        return Session(db.engine)
    if isinstance(session_factory, Session):
        # Fix: previously the caller's session was returned directly, so the
        # task's `with` block closed a session it did not own.
        return contextlib.nullcontext(session_factory)
    return session_factory()
@shared_task(queue="mail")
def dispatch_human_input_email_task(form_id: str, node_title: str | None = None, session_factory=None):
    """Send the human-input form emails for ``form_id`` on the mail queue.

    No-ops when mail is not configured or the tenant's plan does not include
    email delivery. ``session_factory`` exists for tests; a pre-built
    ``Session`` may be passed directly. NOTE(review): ``node_title`` is
    unused here — presumably reserved for the subject line; confirm with
    callers.
    """
    if not mail.is_inited():
        return
    logger.info(click.style(f"Start human input email delivery for form {form_id}", fg="green"))
    start_at = time.perf_counter()
    try:
        with _open_session(session_factory) as session:
            form = session.get(HumanInputForm, form_id)
            if form is None:
                logger.warning("Human input form not found, form_id=%s", form_id)
                return
            features = FeatureService.get_features(form.tenant_id)
            if not features.human_input_email_delivery_enabled:
                logger.info(
                    "Human input email delivery is not available for tenant=%s, form_id=%s",
                    form.tenant_id,
                    form_id,
                )
                return
            jobs = _load_email_jobs(session, form)
            # The variable pool lets body templates reference workflow variables.
            variable_pool = _load_variable_pool(form.workflow_run_id)
            for job in jobs:
                for recipient in job.recipients:
                    form_link = _build_form_link(recipient.token)
                    body = _render_body(job.body, form_link, variable_pool=variable_pool)
                    mail.send(
                        to=recipient.email,
                        subject=job.subject,
                        html=body,
                    )
        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Human input email delivery succeeded for form {form_id}: latency: {end_at - start_at}", fg="green"
            )
        )
    except Exception:
        # Last-resort guard: email delivery must never crash the worker.
        logger.exception("Send human input email failed, form_id=%s", form_id)