mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 19:22:10 -05:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
3
api/tasks/app_generate/__init__.py
Normal file
3
api/tasks/app_generate/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .workflow_execute_task import AppExecutionParams, resume_app_execution, workflow_based_app_execution_task
|
||||
|
||||
__all__ = ["AppExecutionParams", "resume_app_execution", "workflow_based_app_execution_task"]
|
||||
491
api/tasks/app_generate/workflow_execute_task.py
Normal file
491
api/tasks/app_generate/workflow_execute_task.py
Normal file
@@ -0,0 +1,491 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, TypeAlias, Union
|
||||
|
||||
from celery import shared_task
|
||||
from flask import current_app, json
|
||||
from pydantic import BaseModel, Discriminator, Field, Tag
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from extensions.ext_database import db
|
||||
from libs.flask_utils import set_login_user
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKFLOW_BASED_APP_EXECUTION_QUEUE = "workflow_based_app_execution"
|
||||
|
||||
|
||||
class _UserType(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
|
||||
class _Account(BaseModel):
|
||||
TYPE: _UserType = _UserType.ACCOUNT
|
||||
|
||||
user_id: str
|
||||
|
||||
|
||||
class _EndUser(BaseModel):
|
||||
TYPE: _UserType = _UserType.END_USER
|
||||
end_user_id: str
|
||||
|
||||
|
||||
def _get_user_type_descriminator(value: Any):
|
||||
if isinstance(value, (_Account, _EndUser)):
|
||||
return value.TYPE
|
||||
elif isinstance(value, dict):
|
||||
user_type_str = value.get("TYPE")
|
||||
if user_type_str is None:
|
||||
return None
|
||||
try:
|
||||
user_type = _UserType(user_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return user_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
User: TypeAlias = Annotated[
|
||||
(Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]),
|
||||
Discriminator(_get_user_type_descriminator),
|
||||
]
|
||||
|
||||
|
||||
class AppExecutionParams(BaseModel):
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
tenant_id: str
|
||||
app_mode: AppMode = AppMode.ADVANCED_CHAT
|
||||
user: User
|
||||
args: Mapping[str, Any]
|
||||
|
||||
invoke_from: InvokeFrom
|
||||
streaming: bool = True
|
||||
call_depth: int = 0
|
||||
root_node_id: str | None = None
|
||||
workflow_run_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
root_node_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
):
|
||||
user_params: _Account | _EndUser
|
||||
if isinstance(user, Account):
|
||||
user_params = _Account(user_id=user.id)
|
||||
elif isinstance(user, EndUser):
|
||||
user_params = _EndUser(end_user_id=user.id)
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
return cls(
|
||||
app_id=app_model.id,
|
||||
workflow_id=workflow.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
user=user_params,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=call_depth,
|
||||
root_node_id=root_node_id,
|
||||
workflow_run_id=workflow_run_id or str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
|
||||
class _AppRunner:
|
||||
def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._exec_params = exec_params
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _session(self):
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
yield session
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _setup_flask_context(self, user: Account | EndUser):
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
with flask_app.app_context():
|
||||
set_login_user(user)
|
||||
yield
|
||||
|
||||
def run(self):
|
||||
exec_params = self._exec_params
|
||||
with self._session() as session:
|
||||
workflow = session.get(Workflow, exec_params.workflow_id)
|
||||
if workflow is None:
|
||||
logger.warning("Workflow %s not found for execution", exec_params.workflow_id)
|
||||
return None
|
||||
app = session.get(App, workflow.app_id)
|
||||
if app is None:
|
||||
logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id)
|
||||
return None
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=self._session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
user = self._resolve_user()
|
||||
|
||||
with self._setup_flask_context(user):
|
||||
response = self._run_app(
|
||||
app=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
if not exec_params.streaming:
|
||||
return response
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode)
|
||||
|
||||
def _run_app(
|
||||
self,
|
||||
*,
|
||||
app: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
pause_state_config: PauseStateLayerConfig,
|
||||
):
|
||||
exec_params = self._exec_params
|
||||
if exec_params.app_mode == AppMode.ADVANCED_CHAT:
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=exec_params.args,
|
||||
invoke_from=exec_params.invoke_from,
|
||||
streaming=exec_params.streaming,
|
||||
workflow_run_id=exec_params.workflow_run_id,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
if exec_params.app_mode == AppMode.WORKFLOW:
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=exec_params.args,
|
||||
invoke_from=exec_params.invoke_from,
|
||||
streaming=exec_params.streaming,
|
||||
call_depth=exec_params.call_depth,
|
||||
root_node_id=exec_params.root_node_id,
|
||||
workflow_run_id=exec_params.workflow_run_id,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
logger.error("Unsupported app mode for execution: %s", exec_params.app_mode)
|
||||
return None
|
||||
|
||||
def _resolve_user(self) -> Account | EndUser:
|
||||
user_params = self._exec_params.user
|
||||
|
||||
if isinstance(user_params, _EndUser):
|
||||
with self._session() as session:
|
||||
return session.get(EndUser, user_params.end_user_id)
|
||||
elif not isinstance(user_params, _Account):
|
||||
raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}")
|
||||
|
||||
with self._session() as session:
|
||||
user: Account = session.get(Account, user_params.user_id)
|
||||
user.set_tenant_id(self._exec_params.tenant_id)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
|
||||
role = CreatorUserRole(workflow_run.created_by_role)
|
||||
if role == CreatorUserRole.ACCOUNT:
|
||||
user = session.get(Account, workflow_run.created_by)
|
||||
if user:
|
||||
user.set_tenant_id(workflow_run.tenant_id)
|
||||
return user
|
||||
|
||||
return session.get(EndUser, workflow_run.created_by)
|
||||
|
||||
|
||||
def _publish_streaming_response(
|
||||
response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode
|
||||
) -> None:
|
||||
topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id)
|
||||
for event in response_stream:
|
||||
try:
|
||||
payload = json.dumps(event)
|
||||
except TypeError:
|
||||
logger.exception("error while encoding event")
|
||||
continue
|
||||
|
||||
topic.publish(payload.encode())
|
||||
|
||||
|
||||
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE)
|
||||
def workflow_based_app_execution_task(
|
||||
payload: str,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None] | Mapping[str, Any] | None:
|
||||
exec_params = AppExecutionParams.model_validate_json(payload)
|
||||
|
||||
logger.info("workflow_based_app_execution_task run with params: %s", exec_params)
|
||||
|
||||
runner = _AppRunner(db.engine, exec_params=exec_params)
|
||||
return runner.run()
|
||||
|
||||
|
||||
def _resume_app_execution(payload: dict[str, Any]) -> None:
|
||||
workflow_run_id = payload["workflow_run_id"]
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory)
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.warning("No pause entity found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
conversation = None
|
||||
message = None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is None:
|
||||
logger.warning("Workflow run %s not found during resume", workflow_run_id)
|
||||
return
|
||||
|
||||
workflow = session.get(Workflow, workflow_run.workflow_id)
|
||||
if workflow is None:
|
||||
logger.warning("Workflow %s not found during resume", workflow_run.workflow_id)
|
||||
return
|
||||
|
||||
app_model = session.get(App, workflow_run.app_id)
|
||||
if app_model is None:
|
||||
logger.warning("App %s not found during resume", workflow_run.app_id)
|
||||
return
|
||||
|
||||
user = _resolve_user_for_run(session, workflow_run)
|
||||
if user is None:
|
||||
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
|
||||
return
|
||||
|
||||
if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
|
||||
if generate_entity.conversation_id is None:
|
||||
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
conversation = session.get(Conversation, generate_entity.conversation_id)
|
||||
if conversation is None:
|
||||
logger.warning(
|
||||
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
|
||||
)
|
||||
return
|
||||
|
||||
message = session.scalar(
|
||||
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
if message is None:
|
||||
logger.warning("Message not found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)):
|
||||
logger.error(
|
||||
"Unsupported resumption entity for workflow run %s (found %s)",
|
||||
workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
_resume_advanced_chat(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
session_factory=session_factory,
|
||||
pause_state_config=pause_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run=workflow_run,
|
||||
)
|
||||
elif isinstance(generate_entity, WorkflowAppGenerateEntity):
|
||||
_resume_workflow(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
session_factory=session_factory,
|
||||
pause_state_config=pause_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run=workflow_run,
|
||||
workflow_run_repo=workflow_run_repo,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
|
||||
def _resume_advanced_chat(
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
generate_entity: AdvancedChatAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
session_factory: sessionmaker,
|
||||
pause_state_config: PauseStateLayerConfig,
|
||||
workflow_run_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> None:
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
try:
|
||||
response = generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
application_generate_entity=generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
|
||||
|
||||
|
||||
def _resume_workflow(
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
session_factory: sessionmaker,
|
||||
pause_state_config: PauseStateLayerConfig,
|
||||
workflow_run_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
workflow_run_repo,
|
||||
pause_entity,
|
||||
) -> None:
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
try:
|
||||
response = generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
|
||||
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
|
||||
|
||||
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")
|
||||
def resume_app_execution(payload: dict[str, Any]) -> None:
|
||||
_resume_app_execution(payload)
|
||||
@@ -5,32 +5,42 @@ These tasks handle workflow execution for different subscription tiers
|
||||
with appropriate retry policies and error handling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.app.layers.timeslice_layer import TimeSliceLayer
|
||||
from core.app.layers.trigger_post_layer import TriggerPostLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole, WorkflowTriggerStatus
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
|
||||
from models.model import App, EndUser, Tenant
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import WorkflowNotFoundError
|
||||
from services.workflow.entities import (
|
||||
TriggerData,
|
||||
WorkflowResumeTaskData,
|
||||
WorkflowTaskData,
|
||||
)
|
||||
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler
|
||||
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
|
||||
def execute_workflow_professional(task_data_dict: dict[str, Any]):
|
||||
@@ -141,6 +151,11 @@ def _execute_workflow_common(
|
||||
if trigger_data.workflow_id:
|
||||
args["workflow_id"] = str(trigger_data.workflow_id)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory.get_session_maker(),
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
# Execute the workflow with the trigger type
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
@@ -156,6 +171,7 @@ def _execute_workflow_common(
|
||||
# TODO: Re-enable TimeSliceLayer after the HITL release.
|
||||
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
|
||||
],
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -173,21 +189,153 @@ def _execute_workflow_common(
|
||||
session.commit()
|
||||
|
||||
|
||||
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
|
||||
@shared_task(name="resume_workflow_execution")
|
||||
def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
|
||||
"""Resume a paused workflow run via Celery."""
|
||||
task_data = WorkflowResumeTaskData.model_validate(task_data_dict)
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
|
||||
return
|
||||
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(pause_entity.workflow_execution_id)
|
||||
if workflow_run is None:
|
||||
logger.warning("Workflow run not found for pause entity: pause_entity_id=%s", pause_entity.id)
|
||||
return
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
|
||||
raise exc
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if not isinstance(generate_entity, WorkflowAppGenerateEntity):
|
||||
logger.error(
|
||||
"Unsupported resumption entity for workflow run %s: %s",
|
||||
task_data.workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
with session_factory() as session:
|
||||
workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if workflow is None:
|
||||
raise WorkflowNotFoundError(
|
||||
"Workflow not found: workflow_run_id=%s, workflow_id=%s", workflow_run.id, workflow_run.workflow_id
|
||||
)
|
||||
user = _get_user(session, workflow_run)
|
||||
app_model = session.scalar(select(App).where(App.id == workflow_run.app_id))
|
||||
if app_model is None:
|
||||
raise _AppNotFoundError(
|
||||
"App not found: app_id=%s, workflow_run_id=%s", workflow_run.app_id, workflow_run.id
|
||||
)
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom(workflow_run.triggered_from),
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
start_time = datetime.now(UTC)
|
||||
graph_engine_layers = []
|
||||
trigger_log = _query_trigger_log_info(session_factory, task_data.workflow_run_id)
|
||||
|
||||
if trigger_log:
|
||||
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
|
||||
queue=AsyncWorkflowQueue(trigger_log.queue_name),
|
||||
schedule_strategy=AsyncWorkflowSystemStrategy,
|
||||
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
|
||||
)
|
||||
cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)
|
||||
|
||||
graph_engine_layers.extend(
|
||||
[
|
||||
TimeSliceLayer(cfs_plan_scheduler),
|
||||
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
|
||||
]
|
||||
)
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
|
||||
|
||||
generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
|
||||
|
||||
def _get_user(session: Session, workflow_run: WorkflowRun | WorkflowTriggerLog) -> Account | EndUser:
|
||||
"""Compose user from trigger log"""
|
||||
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
|
||||
tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id))
|
||||
if not tenant:
|
||||
raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")
|
||||
raise _TenantNotFoundError(
|
||||
"Tenant not found for WorkflowRun: tenant_id=%s, workflow_run_id=%s",
|
||||
workflow_run.tenant_id,
|
||||
workflow_run.id,
|
||||
)
|
||||
|
||||
# Get user from trigger log
|
||||
if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
user = session.scalar(select(Account).where(Account.id == workflow_run.created_by))
|
||||
if user:
|
||||
user.current_tenant = tenant
|
||||
else: # CreatorUserRole.END_USER
|
||||
user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))
|
||||
user = session.scalar(select(EndUser).where(EndUser.id == workflow_run.created_by))
|
||||
|
||||
if not user:
|
||||
raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")
|
||||
raise _UserNotFoundError(
|
||||
"User not found: user_id=%s, created_by_role=%s, workflow_run_id=%s",
|
||||
workflow_run.created_by,
|
||||
workflow_run.created_by_role,
|
||||
workflow_run.id,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run_id) -> WorkflowTriggerLog | None:
|
||||
with session_factory() as session, session.begin():
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
|
||||
if not trigger_log:
|
||||
logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id)
|
||||
return None
|
||||
|
||||
return trigger_log
|
||||
|
||||
|
||||
class _TenantNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _UserNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _AppNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
113
api/tasks/human_input_timeout_tasks.py
Normal file
113
api/tasks/human_input_timeout_tasks.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.human_input_repository import HumanInputFormSubmissionRepository
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
|
||||
from models.human_input import HumanInputForm
|
||||
from models.workflow import WorkflowPause, WorkflowRun
|
||||
from services.human_input_service import HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int, *, now) -> bool:
|
||||
if global_timeout_seconds <= 0:
|
||||
return False
|
||||
if form_model.workflow_run_id is None:
|
||||
return False
|
||||
created_at = ensure_naive_utc(form_model.created_at)
|
||||
global_deadline = created_at + timedelta(seconds=global_timeout_seconds)
|
||||
return global_deadline <= now
|
||||
|
||||
|
||||
def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None:
|
||||
now = naive_utc_now()
|
||||
with session_factory() as session, session.begin():
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is not None:
|
||||
workflow_run.status = WorkflowExecutionStatus.STOPPED
|
||||
workflow_run.error = f"Human input global timeout at node {node_id}"
|
||||
workflow_run.finished_at = now
|
||||
session.add(workflow_run)
|
||||
|
||||
pause_model = session.scalar(select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id))
|
||||
if pause_model is not None:
|
||||
try:
|
||||
storage.delete(pause_model.state_object_key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to delete pause state object for workflow_run_id=%s, pause_id=%s",
|
||||
workflow_run_id,
|
||||
pause_model.id,
|
||||
)
|
||||
pause_model.resumed_at = now
|
||||
session.add(pause_model)
|
||||
|
||||
|
||||
@shared_task(name="human_input_form_timeout.check_and_resume", queue="schedule_executor")
|
||||
def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
|
||||
"""Scan for expired human input forms and resume or end workflows."""
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
form_repo = HumanInputFormSubmissionRepository(session_factory)
|
||||
service = HumanInputService(session_factory, form_repository=form_repo)
|
||||
now = naive_utc_now()
|
||||
global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
|
||||
|
||||
with session_factory() as session:
|
||||
global_deadline = now - timedelta(seconds=global_timeout_seconds) if global_timeout_seconds > 0 else None
|
||||
timeout_filter = HumanInputForm.expiration_time <= now
|
||||
if global_deadline is not None:
|
||||
timeout_filter = or_(timeout_filter, HumanInputForm.created_at <= global_deadline)
|
||||
stmt = (
|
||||
select(HumanInputForm)
|
||||
.where(
|
||||
HumanInputForm.status == HumanInputFormStatus.WAITING,
|
||||
timeout_filter,
|
||||
)
|
||||
.order_by(HumanInputForm.id.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
expired_forms = session.scalars(stmt).all()
|
||||
|
||||
for form_model in expired_forms:
|
||||
try:
|
||||
if form_model.form_kind == HumanInputFormKind.DELIVERY_TEST:
|
||||
form_repo.mark_timeout(
|
||||
form_id=form_model.id,
|
||||
timeout_status=HumanInputFormStatus.TIMEOUT,
|
||||
reason="delivery_test_timeout",
|
||||
)
|
||||
continue
|
||||
|
||||
is_global = _is_global_timeout(form_model, global_timeout_seconds, now=now)
|
||||
record = form_repo.mark_timeout(
|
||||
form_id=form_model.id,
|
||||
timeout_status=HumanInputFormStatus.EXPIRED if is_global else HumanInputFormStatus.TIMEOUT,
|
||||
reason="global_timeout" if is_global else "node_timeout",
|
||||
)
|
||||
assert record.workflow_run_id is not None, "workflow_run_id should not be None for non-test form"
|
||||
if is_global:
|
||||
_handle_global_timeout(
|
||||
form_id=record.form_id,
|
||||
workflow_run_id=record.workflow_run_id,
|
||||
node_id=record.node_id,
|
||||
session_factory=session_factory,
|
||||
)
|
||||
else:
|
||||
service.enqueue_resume(record.workflow_run_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to handle timeout for form_id=%s workflow_run_id=%s",
|
||||
form_model.id,
|
||||
form_model.workflow_run_id,
|
||||
)
|
||||
190
api/tasks/mail_human_input_delivery_task.py
Normal file
190
api/tasks/mail_human_input_delivery_task.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_mail import mail
|
||||
from models.human_input import (
|
||||
DeliveryMethodType,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _EmailRecipient:
|
||||
email: str
|
||||
token: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _EmailDeliveryJob:
|
||||
form_id: str
|
||||
subject: str
|
||||
body: str
|
||||
form_content: str
|
||||
recipients: list[_EmailRecipient]
|
||||
|
||||
|
||||
def _build_form_link(token: str) -> str:
|
||||
base_url = dify_config.APP_WEB_URL
|
||||
return f"{base_url.rstrip('/')}/form/{token}"
|
||||
|
||||
|
||||
def _parse_recipient_payload(payload: str) -> tuple[str | None, RecipientType | None]:
|
||||
try:
|
||||
payload_dict: dict[str, Any] = json.loads(payload)
|
||||
except Exception:
|
||||
logger.exception("Failed to parse recipient payload")
|
||||
return None, None
|
||||
|
||||
return payload_dict.get("email"), payload_dict.get("TYPE")
|
||||
|
||||
|
||||
def _load_email_jobs(session: Session, form: HumanInputForm) -> list[_EmailDeliveryJob]:
|
||||
deliveries = session.scalars(
|
||||
select(HumanInputDelivery).where(
|
||||
HumanInputDelivery.form_id == form.id,
|
||||
HumanInputDelivery.delivery_method_type == DeliveryMethodType.EMAIL,
|
||||
)
|
||||
).all()
|
||||
jobs: list[_EmailDeliveryJob] = []
|
||||
for delivery in deliveries:
|
||||
delivery_config = EmailDeliveryMethod.model_validate_json(delivery.channel_payload)
|
||||
|
||||
recipients = session.scalars(
|
||||
select(HumanInputFormRecipient).where(HumanInputFormRecipient.delivery_id == delivery.id)
|
||||
).all()
|
||||
|
||||
recipient_entities: list[_EmailRecipient] = []
|
||||
for recipient in recipients:
|
||||
email, recipient_type = _parse_recipient_payload(recipient.recipient_payload)
|
||||
if recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}:
|
||||
continue
|
||||
if not email:
|
||||
continue
|
||||
token = recipient.access_token
|
||||
if not token:
|
||||
continue
|
||||
recipient_entities.append(_EmailRecipient(email=email, token=token))
|
||||
|
||||
if not recipient_entities:
|
||||
continue
|
||||
|
||||
jobs.append(
|
||||
_EmailDeliveryJob(
|
||||
form_id=form.id,
|
||||
subject=delivery_config.config.subject,
|
||||
body=delivery_config.config.body,
|
||||
form_content=form.rendered_content,
|
||||
recipients=recipient_entities,
|
||||
)
|
||||
)
|
||||
return jobs
|
||||
|
||||
|
||||
def _render_body(
|
||||
body_template: str,
|
||||
form_link: str,
|
||||
*,
|
||||
variable_pool: VariablePool | None,
|
||||
) -> str:
|
||||
body = EmailDeliveryConfig.render_body_template(
|
||||
body=body_template,
|
||||
url=form_link,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
return body
|
||||
|
||||
|
||||
def _load_variable_pool(workflow_run_id: str | None) -> VariablePool | None:
|
||||
if not workflow_run_id:
|
||||
return None
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.info("No pause state found for workflow run %s", workflow_run_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
|
||||
return None
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
return graph_runtime_state.variable_pool
|
||||
|
||||
|
||||
def _open_session(session_factory: sessionmaker | Session | None):
|
||||
if session_factory is None:
|
||||
return Session(db.engine)
|
||||
if isinstance(session_factory, Session):
|
||||
return session_factory
|
||||
return session_factory()
|
||||
|
||||
|
||||
@shared_task(queue="mail")
|
||||
def dispatch_human_input_email_task(form_id: str, node_title: str | None = None, session_factory=None):
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start human input email delivery for form {form_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
with _open_session(session_factory) as session:
|
||||
form = session.get(HumanInputForm, form_id)
|
||||
if form is None:
|
||||
logger.warning("Human input form not found, form_id=%s", form_id)
|
||||
return
|
||||
features = FeatureService.get_features(form.tenant_id)
|
||||
if not features.human_input_email_delivery_enabled:
|
||||
logger.info(
|
||||
"Human input email delivery is not available for tenant=%s, form_id=%s",
|
||||
form.tenant_id,
|
||||
form_id,
|
||||
)
|
||||
return
|
||||
jobs = _load_email_jobs(session, form)
|
||||
|
||||
variable_pool = _load_variable_pool(form.workflow_run_id)
|
||||
|
||||
for job in jobs:
|
||||
for recipient in job.recipients:
|
||||
form_link = _build_form_link(recipient.token)
|
||||
body = _render_body(job.body, form_link, variable_pool=variable_pool)
|
||||
|
||||
mail.send(
|
||||
to=recipient.email,
|
||||
subject=job.subject,
|
||||
html=body,
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Human input email delivery succeeded for form {form_id}: latency: {end_at - start_at}", fg="green"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Send human input email failed, form_id=%s", form_id)
|
||||
Reference in New Issue
Block a user