mirror of
https://github.com/langgenius/dify.git
synced 2026-02-17 19:00:55 -05:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@@ -23,6 +23,7 @@ from . import (
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
human_input_form,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
@@ -30,6 +31,7 @@ from . import (
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
workflow_events,
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
@@ -44,6 +46,7 @@ __all__ = [
|
||||
"feature",
|
||||
"files",
|
||||
"forgot_password",
|
||||
"human_input_form",
|
||||
"login",
|
||||
"message",
|
||||
"passport",
|
||||
@@ -52,4 +55,5 @@ __all__ = [
|
||||
"site",
|
||||
"web_ns",
|
||||
"workflow",
|
||||
"workflow_events",
|
||||
]
|
||||
|
||||
@@ -117,6 +117,12 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
code = 429
|
||||
|
||||
|
||||
class WebFormRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "web_form_rate_limit_exceeded"
|
||||
description = "Too many form requests. Please try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
||||
161
api/controllers/web/human_input_form.py
Normal file
161
api/controllers/web/human_input_form.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Web App Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
_FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_access_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
}
|
||||
if site_payload is not None:
|
||||
payload["site"] = site_payload
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
# NOTE(QuantumGhost): this endpoint is unauthenticated on purpose for now.
|
||||
|
||||
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
Get human input form definition by token.
|
||||
|
||||
GET /api/form/human_input/<form_token>
|
||||
"""
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
service.ensure_form_active(form)
|
||||
app_model, site = _get_app_site_from_form(form)
|
||||
|
||||
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
|
||||
|
||||
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Submit human input form by token.
|
||||
|
||||
POST /api/form/human_input/<form_token>
|
||||
|
||||
Request body:
|
||||
{
|
||||
"inputs": {
|
||||
"content": "User input content"
|
||||
},
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
if (recipient_type := form.recipient_type) is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%", form.id)
|
||||
raise AssertionError("Recipient type is None")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_end_user_id=None,
|
||||
# submission_end_user_id=_end_user.id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
return {}, 200
|
||||
|
||||
|
||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return app_model, site
|
||||
@@ -1,4 +1,6 @@
|
||||
from flask_restx import fields, marshal_with
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -7,7 +9,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
from models.account import TenantStatus
|
||||
from models.model import Site
|
||||
from models.model import App, Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@@ -108,3 +110,14 @@ class AppSiteInfo:
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
112
api/controllers/web/workflow_events.py
Normal file
112
api/controllers/web/workflow_events.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Web App Workflow Resume APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
class WorkflowEventsApi(WebApiResource):
|
||||
"""API for getting workflow execution events after resume."""
|
||||
|
||||
def get(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /api/workflow/<task_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
workflow_run_id = task_id
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER:
|
||||
raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by != end_user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=end_user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(app_mode, workflow_run.id),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Register the APIs
|
||||
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")
|
||||
Reference in New Issue
Block a user