Merge branch 'main' into tp

This commit is contained in:
JzoNg
2026-04-27 10:20:08 +08:00
316 changed files with 12484 additions and 7806 deletions

View File

@@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

View File

@@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for Creators Platform integration
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable Creators Platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client ID for Creators Platform integration",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@@ -1379,6 +1400,7 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
action: str

View File

@@ -692,6 +692,32 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")

View File

@@ -8,10 +8,10 @@ from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@staticmethod
def _ensure_console_recipient_type(form: Form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
raise NotFoundError("form not found")
@setup_required
@login_required
@account_initialization_required
@@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."

View File

@@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword")
@@ -70,6 +75,7 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
return "", 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
def _require_tag_binding_edit_permission() -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
"""Canonical collection resource for tag binding creation."""
@console_ns.doc("create_tag_binding")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
return _create_tag_bindings()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
return _remove_tag_binding()

View File

@@ -23,9 +23,11 @@ from .app import (
conversation,
file,
file_preview,
human_input_form,
message,
site,
workflow,
workflow_events,
)
from .dataset import (
dataset,
@@ -50,6 +52,7 @@ __all__ = [
"file",
"file_preview",
"hit_testing",
"human_input_form",
"index",
"message",
"metadata",
@@ -58,6 +61,7 @@ __all__ = [
"segment",
"site",
"workflow",
"workflow_events",
]
api.add_namespace(service_api_ns)

View File

@@ -0,0 +1,137 @@
"""
Service API human input form endpoints.
This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
# Keep app-token callers scoped to the public web-form surface; internal HITL
# routes must continue to flow through console-only authentication.
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
raise NotFound("Form not found")
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
@service_api_ns.doc("get_human_input_form")
@service_api_ns.doc(description="Get a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form retrieved successfully",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token
def get(self, app_model: App, form_token: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc("submit_human_input_form")
@service_api_ns.doc(description="Submit a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form submitted successfully",
400: "Bad request - invalid submission data",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, form_token: str):
payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
recipient_type = form.recipient_type
if recipient_type is None:
logger.warning("Recipient type is None for form, form_id=%s", form.id)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=end_user.id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
"""Service API for getting workflow execution events after resume."""
@service_api_ns.doc("get_workflow_events")
@service_api_ns.doc(description="Get workflow execution events stream after resume")
@service_api_ns.doc(
params={
"task_id": "Workflow run ID",
"user": "End user identifier (query param)",
"include_state_snapshot": (
"Whether to replay from persisted state snapshot, "
'specify `"true"` to include a status snapshot of executed nodes'
),
"continue_on_pause": (
"Whether to keep the stream open across workflow_paused events,"
'specify `"true"` to keep the stream open for `workflow_paused` events.'
),
}
)
@service_api_ns.doc(
responses={
200: "SSE event stream",
401: "Unauthorized - invalid API token",
404: "Workflow run not found",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise NotWorkflowAppError()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if workflow_run.created_by_role != CreatorUserRole.END_USER:
raise NotFound("Workflow run not found")
if workflow_run.created_by != end_user.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=end_user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise NotWorkflowAppError()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)

View File

@@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

View File

@@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
) -> (
ChatbotAppBlockingResponse
| AdvancedChatPausedBlockingResponse
| Generator[ChatbotAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@@ -3,7 +3,7 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppBlockingResponse,
AdvancedChatPausedBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
@@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
StreamEvent,
)
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
paused_data = blocking_response.data.model_dump(mode="json")
return {
"event": StreamEvent.WORKFLOW_PAUSED.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
"workflow_run_id": blocking_response.data.workflow_run_id,
"data": paused_data,
}
response = {
"event": "message",
"event": StreamEvent.MESSAGE.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
@@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
return response

View File

@@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
WorkflowPauseStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
AdvancedChatPausedBlockingResponse,
Generator[ChatbotAppStreamResponse, None, None],
]:
"""
Process generate task pipeline.
:return:
@@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
"""
Process blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
@@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> AdvancedChatPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return AdvancedChatPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=paused_nodes,
reasons=reasons,
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:

View File

@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@@ -1,7 +1,9 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Union, cast
from pydantic import JsonValue
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
@classmethod
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
return cast(TBlockingResponse, response)
@classmethod
def convert(
@@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, Any]] = {
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data: dict[str, Any] | None = None
data: dict[str, JsonValue] | None = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@@ -336,7 +337,26 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
},
)
responses: list[StreamResponse] = []

View File

@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
)
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
response: dict[str, Any] = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
@@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable, Generator, Mapping
from collections.abc import Callable, Generator, Iterable, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
@@ -26,6 +27,7 @@ class MessageGenerator:
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
@@ -33,4 +35,5 @@ class MessageGenerator:
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
terminal_events=terminal_events,
)

View File

@@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
@@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@@ -59,7 +59,7 @@ def stream_topic_events(
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
if terminal_events is None:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:

View File

@@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.model_dump()
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@@ -42,12 +42,15 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
@@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
]:
"""
Process generate task pipeline.
:return:
@@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
"""
To blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
@@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
@@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
),
)
return response
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> WorkflowAppPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
created_at = int(runtime_state.start_at)
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
id=human_input_responses[-1].workflow_run_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
created_at=created_at,
finished_at=None,
paused_nodes=paused_nodes,
reasons=reasons,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:

View File

@@ -1,12 +1,13 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction
@@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
data: Data
class HumanInputRequiredPauseReasonPayload(BaseModel):
"""
Public pause-reason payload used by blocking responses when only
``human_input_required`` events are available.
"""
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@classmethod
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
return cls(
form_id=data.form_id,
node_id=data.node_id,
node_title=data.node_title,
form_content=data.form_content,
inputs=data.inputs,
actions=data.actions,
display_in_ui=data.display_in_ui,
form_token=data.form_token,
resolved_default_values=data.resolved_default_values,
expiration_time=data.expiration_time,
)
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
@@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
"""
ChatbotAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
mode: str
conversation_id: str
message_id: str
workflow_run_id: str
answer: str
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
status: WorkflowExecutionStatus
elapsed_time: float
total_tokens: int
total_steps: int
data: Data
class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
@@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
"""
WorkflowAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
workflow_id: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_at: int
finished_at: int | None
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
workflow_run_id: str
data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(
self,
@@ -30,8 +44,12 @@ class DifyCredentialsProvider:
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@@ -46,6 +64,7 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@@ -65,7 +84,8 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
),
enable_credentials_cache=True,
)
self.model_manager = model_manager
@@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@@ -0,0 +1,41 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
response.raise_for_status()
data = response.json()
claim_code = data.get("data", {}).get("claim_code")
if not claim_code:
raise ValueError("Creators Platform did not return a valid claim_code")
return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
params: dict[str, str] = {"dsl_claim_code": claim_code}
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
if client_id:
from services.oauth_server import OAuthServerService
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
params["oauth_code"] = oauth_code
return f"{base_url}?{urlencode(params)}"

View File

@@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
@@ -217,8 +215,8 @@ class LLMGenerator:
else:
# Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = {
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
"max_tokens": 2560,
"temperature": 0.0,
}
stop = []

View File

@@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def __init__(self, instruction_prompt: str | None = None) -> None:
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
@staticmethod
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
if not instruction_prompt or not instruction_prompt.strip():
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
def get_format_instructions(self) -> str:
return self._instruction_prompt

View File

@@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
@@ -36,11 +37,13 @@ class ModelInstance:
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@@ -434,8 +437,30 @@ class ModelInstance:
class ModelManager:
def __init__(self, provider_manager: ProviderManager):
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@@ -463,8 +488,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"])
data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(

View File

@@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
topK=max_keywords_per_chunk or 10,
)
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)

View File

@@ -551,6 +551,7 @@ class RetrievalService:
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
assert i.index_node_id
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)

View File

@@ -11,6 +11,7 @@ from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.enums import SegmentType
class DatasetDocumentStore:
@@ -127,6 +128,7 @@ class DatasetDocumentStore:
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@@ -137,7 +139,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)
@@ -163,6 +165,7 @@ class DatasetDocumentStore:
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@@ -173,7 +176,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)

View File

@@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
stream=False, # pyright: ignore[reportArgumentType]
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()

View File

@@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
class SystemEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Encrypt parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
@@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
return params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an OAuth encrypter instance.
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
_encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global OAuth encrypter instance.
Get the global encrypter instance.
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Encrypt parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
params: Parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
return get_system_encrypter().decrypt_params(encrypted_data)

View File

@@ -12,20 +12,16 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType
_FORM_TOKEN_PRIORITY = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def load_form_tokens_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
"""Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids))
@@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
return {}
if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids)
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
tokens_by_form_id: dict[str, tuple[int, str]] = {}
def _load_form_tokens_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
if priority is None or not recipient.access_token:
if not recipient.access_token:
continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
candidate = (priority, recipient.access_token)
current = tokens_by_form_id.get(recipient.form_id)
if current is None or candidate[0] < current[0]:
tokens_by_form_id[recipient.form_id] = candidate
tokens_by_form_id: dict[str, str] = {}
for form_id, recipients in recipients_by_form_id.items():
token = _get_surface_form_token(recipients, surface=surface)
if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
def _get_surface_form_token(
recipients: Sequence[tuple[RecipientType, str]],
*,
surface: HumanInputSurface | None,
) -> str | None:
if surface == HumanInputSurface.SERVICE_API:
for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token
return get_preferred_form_token(recipients)

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from graphon.entities.pause_reason import PauseReasonType
from models.human_input import RecipientType
class HumanInputSurface(StrEnum):
SERVICE_API = "service_api"
CONSOLE = "console"
# Service API is intentionally narrower than other surfaces: app-token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
}
# A single HITL form can have multiple recipient records; this shared priority
# keeps every API surface consistent about which resume token to expose.
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def is_recipient_type_allowed_for_surface(
recipient_type: RecipientType | None,
surface: HumanInputSurface,
) -> bool:
if recipient_type is None:
return False
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
def get_preferred_form_token(
recipients: Sequence[tuple[RecipientType, str]],
) -> str | None:
chosen_token: str | None = None
chosen_priority: int | None = None
for recipient_type, token in recipients:
priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
if priority is None or not token:
continue
if chosen_priority is None or priority < chosen_priority:
chosen_priority = priority
chosen_token = token
return chosen_token
def enrich_human_input_pause_reasons(
reasons: Sequence[Mapping[str, Any]],
*,
form_tokens_by_form_id: Mapping[str, str],
expiration_times_by_form_id: Mapping[str, int],
) -> list[dict[str, Any]]:
enriched: list[dict[str, Any]] = []
for reason in reasons:
updated = dict(reason)
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_id = updated.get("form_id")
if isinstance(form_id, str):
updated["form_token"] = form_tokens_by_form_id.get(form_id)
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is not None:
updated["expiration_time"] = expiration_time
enriched.append(updated)
return enriched

View File

@@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@@ -1036,7 +1036,7 @@ class DocumentSegment(Base):
return attachment_list
class ChildChunk(Base):
class ChildChunk(TypeBase):
__tablename__ = "child_chunks"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@@ -1046,29 +1046,42 @@ class ChildChunk(Base):
)
# initial fields
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
segment_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
content = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
DateTime,
nullable=False,
server_default=sa.func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
indexing_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
completed_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255),
nullable=False,
server_default=sa.text("'automatic'"),
default=SegmentType.AUTOMATIC,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False)
@property
def dataset(self):

View File

@@ -1867,15 +1867,18 @@ class MessageAnnotation(TypeBase):
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
StringUUID,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID)
question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@@ -225,8 +225,10 @@ class TestSpanBuilder:
span = builder.build_span(span_data)
assert isinstance(span, ReadableSpan)
assert span.name == "test-span"
assert span.context is not None
assert span.context.trace_id == 123
assert span.context.span_id == 456
assert span.parent is not None
assert span.parent.span_id == 789
assert span.resource == resource
assert span.attributes == {"attr1": "val1"}

View File

@@ -64,12 +64,13 @@ class TestSpanData:
def test_span_data_missing_required_fields(self):
with pytest.raises(ValidationError):
SpanData(
trace_id=123,
# span_id missing
name="test_span",
start_time=1000,
end_time=2000,
SpanData.model_validate(
{
"trace_id": 123,
"name": "test_span",
"start_time": 1000,
"end_time": 2000,
}
)
def test_span_data_arbitrary_types_allowed(self):

View File

@@ -2,12 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@@ -44,7 +46,7 @@ class RecordingTraceClient:
self.endpoint = endpoint
self.added_spans: list[object] = []
def add_span(self, span) -> None:
def add_span(self, span: object) -> None:
self.added_spans.append(span)
def api_check(self) -> bool:
@@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags.SAMPLED,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(context)
def _make_trace_metadata(
trace_id: int = 1,
workflow_span_id: int = 2,
session_id: str = "s",
user_id: str = "u",
links: list[Link] | None = None,
) -> TraceMetadata:
return TraceMetadata(
trace_id=trace_id,
workflow_span_id=workflow_span_id,
session_id=session_id,
user_id=user_id,
links=[] if links is None else links,
)
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
return cast(RecordingTraceClient, trace_instance.trace_client)
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
defaults = {
"workflow_id": "workflow-id",
@@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
trace_instance.workflow_trace(trace_info)
add_workflow_span.assert_called_once()
passed_trace_metadata = add_workflow_span.call_args.args[1]
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
assert passed_trace_metadata.trace_id == 111
assert passed_trace_metadata.workflow_span_id == 222
assert passed_trace_metadata.session_id == "c"
assert passed_trace_metadata.user_id == "u"
assert passed_trace_metadata.links == []
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_message_trace_info(message_data=None)
trace_instance.message_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
)
trace_instance.message_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 2
message_span, llm_span = trace_instance.trace_client.added_spans
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span, llm_span = spans
assert message_span.name == "message"
assert message_span.trace_id == 10
@@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
trace_instance.dataset_retrieval_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "dataset_retrieval"
assert span.attributes[RETRIEVAL_QUERY] == "query"
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_tool_trace_info(message_data=None)
trace_instance.tool_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
)
)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "my-tool"
assert span.status == status
assert span.attributes[TOOL_NAME] == "my-tool"
@@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
@@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
@@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
@@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
@@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
node_execution.node_type = BuiltinNodeTypes.CODE
@@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "title"
@@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
trace_metadata = _make_trace_metadata(links=[_make_link()])
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "my-tool"
@@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "retrieval"
@@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "llm"
@@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
# CASE 1: With message_id
trace_info = _make_workflow_trace_info(
@@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
)
trace_instance.add_workflow_span(trace_info, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 2
message_span = trace_instance.trace_client.added_spans[0]
workflow_span = trace_instance.trace_client.added_spans[1]
client = _recording_trace_client(trace_instance)
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span = spans[0]
workflow_span = spans[1]
assert message_span.name == "message"
assert message_span.span_kind == SpanKind.SERVER
@@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
assert workflow_span.span_kind == SpanKind.INTERNAL
assert workflow_span.parent_span_id == 20
trace_instance.trace_client.added_spans.clear()
client.added_spans.clear()
# CASE 2: Without message_id
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "workflow"
assert span.span_kind == SpanKind.SERVER
assert span.parent_span_id is None
@@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
trace_instance.suggested_question_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "suggested_question"
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@@ -1,4 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
from dify_trace_aliyun.entities.semconv import (
@@ -170,7 +172,7 @@ def test_create_common_span_attributes():
def test_format_retrieval_documents():
# Not a list
assert format_retrieval_documents("not a list") == []
assert format_retrieval_documents(cast(list[object], "not a list")) == []
# Valid list
docs = [
@@ -211,7 +213,7 @@ def test_format_retrieval_documents():
def test_format_input_messages():
# Not a dict
assert format_input_messages(None) == serialize_json_data([])
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No prompts
assert format_input_messages({}) == serialize_json_data([])
@@ -244,7 +246,7 @@ def test_format_input_messages():
def test_format_output_messages():
# Not a dict
assert format_output_messages(None) == serialize_json_data([])
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No text
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

View File

@@ -25,13 +25,13 @@ class TestAliyunConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
AliyunConfig.model_validate({})
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
AliyunConfig.model_validate({"license_key": "test_license"})
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""

View File

@@ -1,4 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
@@ -129,7 +130,7 @@ def test_set_span_status():
return "SilentErrorRepr"
span.reset_mock()
set_span_status(span, SilentError())
set_span_status(span, cast(Exception | str | None, SilentError()))
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"

View File

@@ -28,13 +28,13 @@ class TestLangfuseConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
LangfuseConfig.model_validate({})
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
LangfuseConfig.model_validate({"public_key": "public"})
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
LangfuseConfig.model_validate({"secret_key": "secret"})
def test_host_validation_empty(self):
"""Test host validation with empty value"""

View File

@@ -2,6 +2,7 @@
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_langfuse.config import LangfuseConfig
@@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
assert trace._get_completion_start_time(start_time, None) is None
assert trace._get_completion_start_time(start_time, -1) is None
assert trace._get_completion_start_time(start_time, "invalid") is None
assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None

View File

@@ -21,13 +21,13 @@ class TestLangSmithConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangSmithConfig()
LangSmithConfig.model_validate({})
with pytest.raises(ValidationError):
LangSmithConfig(api_key="key")
LangSmithConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
LangSmithConfig(project="project")
LangSmithConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@@ -599,7 +599,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_instance.message_trace(_make_message_trace_info())
mock_tracing["start"].assert_called_once()
@@ -609,7 +608,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(error="something broke")
trace_instance.message_trace(trace_info)
@@ -620,7 +618,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
monkeypatch.setenv("FILES_URL", "http://files.test")
file_data = SimpleNamespace(url="path/to/file.png")
@@ -638,7 +635,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
trace_instance.message_trace(trace_info)
@@ -651,7 +647,6 @@ class TestMessageTrace:
end_user = MagicMock()
end_user.session_id = "session-xyz"
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
trace_info = _make_message_trace_info(
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@@ -664,7 +659,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(
metadata={"from_account_id": "acc-1"},

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
@@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
return instance
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_trace)
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_span)
# ---------------------------------------------------------------------------
# _seed_to_uuid4
# ---------------------------------------------------------------------------
@@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
def test_root_span_is_created(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
assert instance.add_span.called
assert _add_span_mock(instance).called
def test_root_span_id_matches_expected(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
expected = self._expected_root_span_id(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected
def test_root_span_has_no_parent(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["parent_span_id"] is None
def test_trace_name_is_workflow_trace(self):
@@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_name_is_workflow_trace(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_has_workflow_tag(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert "workflow" in root_span_kwargs["tags"]
def test_node_execution_spans_are_parented_to_root(self):
@@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
# call_args_list[0] = root span, [1] = node execution span
assert instance.add_span.call_count == 2
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
add_span = _add_span_mock(instance)
assert add_span.call_count == 2
node_span_kwargs = add_span.call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
def test_node_span_not_parented_to_workflow_app_log_id(self):
@@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] != old_parent_id
def test_root_span_id_differs_from_trace_id(self):
@@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
def test_root_span_uses_workflow_run_id_directly(self):
@@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info)
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected_root_span_id
def test_root_span_id_differs_from_no_message_id_case(self):
@@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from typing import Any, TypedDict, cast
from unittest.mock import MagicMock
import pytest
@@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
@@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
self.kwargs = kwargs
class PatchedCoreComponents(TypedDict):
span_exporter: MagicMock
span_processor: MagicMock
tracer: MagicMock
span: MagicMock
tracer_provider: MagicMock
logger: MagicMock
trace_api: Any
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
"""Drop fake metric modules into sys.modules so the client imports resolve."""
@@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
span_exporter = MagicMock(name="span_exporter")
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
@@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
}
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
return SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
def _build_client() -> TencentTraceClient:
return TencentTraceClient(
service_name="service",
@@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
def test_resolve_grpc_target_handles_errors() -> None:
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
@pytest.mark.parametrize(
@@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
client.record_trace_duration(0.5)
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
logger.debug.assert_called()
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
ctx = _make_span_context(span_id=2)
span.get_span_context.return_value = ctx
data = SpanData(
trace_id=1,
@@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
span.add_event.assert_called_once()
span.set_status.assert_called_once()
span.end.assert_called_once_with(end_time=20)
assert client.span_contexts[2] == "ctx"
assert client.span_contexts[2] == ctx
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.span_contexts[10] = "existing"
existing_context = _make_span_context(span_id=10)
client.span_contexts[10] = existing_context
span = patch_core_components["span"]
span.get_span_context.return_value = "child"
span.get_span_context.return_value = _make_span_context(span_id=11)
data = SpanData(
trace_id=1,
@@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
client._create_and_export_span(data)
trace_api = patch_core_components["trace_api"]
trace_api.NonRecordingSpan.assert_called_once_with("existing")
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
trace_api.set_span_in_context.assert_called_once()
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
client.tracer.start_span.side_effect = RuntimeError("boom")
client._create_and_export_span(
@@ -385,7 +407,7 @@ def test_get_project_url() -> None:
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span_processor = patch_core_components["span_processor"]
tracer_provider = patch_core_components["tracer_provider"]
@@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
metric_reader.shutdown.assert_called_once()
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
meter_provider = meter_provider_instances[-1]
meter_provider.shutdown.side_effect = RuntimeError("boom")
assert client.metric_reader is not None
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
client.shutdown()
@@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
assert client.metric_reader is None
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
@@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
logger.exception.assert_called_once()
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
data = SpanData.model_construct(
trace_id=1,
@@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_llm_duration")
client.hist_llm_duration = hist_mock
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["foo"], str)
assert attrs["bar"] == 2
@@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_trace_duration")
client.hist_trace_duration = hist_mock
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["meta"], str)
assert attrs["ok"] is True
@@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
],
)
def test_record_methods_handle_exceptions(
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
) -> None:
client = _build_client()
hist_mock = MagicMock(name=attr_name)
@@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
def test_metrics_initializes_grpc_metric_exporter() -> None:
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
assert isinstance(exporter, DummyGrpcMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert metric_reader.exporter.kwargs["insecure"] is False
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert exporter.kwargs["insecure"] is False
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
assert isinstance(exporter, DummyHttpMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
assert isinstance(exporter, DummyJsonMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in metric_reader.exporter.kwargs
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in exporter.kwargs
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
_ = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in metric_reader.exporter.kwargs
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in exporter.kwargs
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
WeaveConfig.model_validate({})
with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
WeaveConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
WeaveConfig(project="project")
WeaveConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
from packaging import version
from pydantic import BaseModel, model_validator
@@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None:
# Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name)
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
return False
try:
milvus_version = self._client.get_server_version()
milvus_version_raw = self._client.get_server_version()
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version:
return True

View File

@@ -3,7 +3,7 @@ import json
import logging
import re
import uuid
from typing import Any
from typing import Any, TypedDict
import jieba.posseg as pseg # type: ignore
import numpy
@@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False
class _OraclePoolParams(TypedDict, total=False):
user: str
password: str
dsn: str
min: int
max: int
increment: int
config_dir: str | None
wallet_location: str | None
wallet_password: str | None
class OracleVectorConfig(BaseModel):
user: str
password: str
@@ -127,22 +139,18 @@ class OracleVector(BaseVector):
return connection
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
"max": 5,
"increment": 1,
}
pool_params = _OraclePoolParams(
user=config.user,
password=config.password,
dsn=config.dsn,
min=1,
max=5,
increment=1,
)
if config.is_autonomous:
pool_params.update(
{
"config_dir": config.config_dir,
"wallet_location": config.wallet_location,
"wallet_password": config.wallet_password,
}
)
pool_params["config_dir"] = config.config_dir
pool_params["wallet_location"] = config.wallet_location
pool_params["wallet_password"] = config.wallet_password
return oracledb.create_pool(**pool_params)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):

View File

@@ -9,6 +9,7 @@ dependencies = [
"boto3>=1.42.91",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask>=3.1.3,<4.0.0",
"flask-cors>=6.0.2",
"gevent>=26.4.0",
"gevent-websocket>=0.10.1",
@@ -117,7 +118,7 @@ dev = [
"faker>=40.15.0",
"lxml-stubs>=0.5.1",
"basedpyright>=1.39.3",
"ruff>=0.15.11",
"ruff>=0.15.12",
"pytest>=9.0.3",
"pytest-benchmark>=5.2.3",
"pytest-cov>=7.1.0",
@@ -144,7 +145,7 @@ dev = [
"types-pexpect>=4.9.0",
"types-protobuf>=7.34.1",
"types-psutil>=7.2.2",
"types-psycopg2>=2.9.21",
"types-psycopg2>=2.9.21.20260422",
"types-pygments>=2.20.0",
"types-pymysql>=1.1.0",
"types-python-dateutil>=2.9.0",
@@ -157,9 +158,9 @@ dev = [
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.42.92",
"boto3-stubs>=1.42.96",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.152.1",
"hypothesis>=6.152.3",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
@@ -169,7 +170,7 @@ dev = [
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy>=1.20.1",
"mypy>=1.20.2",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",

View File

@@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from models.enums import WorkflowRunTriggeredFrom
from models.human_input import HumanInputForm
from models.human_input import HumanInputForm, HumanInputFormRecipient
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.entities.workflow_pause import WorkflowPauseEntity
@@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
def _build_human_input_required_reason(
reason_model: WorkflowPauseReason,
form_model: HumanInputForm | None,
recipients: Sequence[HumanInputFormRecipient] = (),
) -> HumanInputRequired:
form_content = ""
inputs = []
@@ -89,7 +90,7 @@ def _build_human_input_required_reason(
resolved_default_values = dict(definition.default_values)
node_title = definition.node_title or node_title
return HumanInputRequired(
reason = HumanInputRequired(
form_id=form_id,
form_content=form_content,
inputs=inputs,
@@ -98,6 +99,7 @@ def _build_human_input_required_reason(
node_title=node_title,
resolved_default_values=resolved_default_values,
)
return reason
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
@@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
for form in session.scalars(form_stmt).all():
form_models[form.id] = form
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
if form_ids:
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(recipient_stmt).all():
recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
pause_reasons: list[PauseReason] = []
for reason in pause_reason_models:
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_model = form_models.get(reason.form_id)
pause_reasons.append(_build_human_input_required_reason(reason, form_model))
pause_reasons.append(
_build_human_input_required_reason(
reason,
form_model,
recipients_by_form_id.get(reason.form_id, ()),
)
)
else:
pause_reasons.append(reason.to_entity())
return pause_reasons

View File

@@ -133,7 +133,14 @@ class AppAnnotationService:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=None,
message_id=None,
content=answer,
question=question,
account_id=current_user.id,
)
db.session.add(annotation)
db.session.commit()

View File

@@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@@ -106,7 +107,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@@ -116,6 +117,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
)
@@ -162,6 +164,7 @@ class AppGenerateService:
invoke_from=invoke_from,
streaming=True,
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
@@ -183,6 +186,10 @@ class AppGenerateService:
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
@@ -194,6 +201,7 @@ class AppGenerateService:
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
pause_state_config=pause_config,
)
),
request_id=request_id,

View File

@@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@@ -88,7 +89,10 @@ class AsyncWorkflowService:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
# commit read only session before starting the billig rpc call
session.commit()
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@@ -131,9 +135,10 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume quota
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@@ -153,13 +158,18 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
@@ -295,13 +305,21 @@ class AsyncWorkflowService:
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
def _get_workflow(
workflow_service: WorkflowService,
app_model: App,
workflow_id: str | None = None,
session: Session | None = None,
) -> Workflow:
"""
Get workflow for the app
Args:
app_model: App model instance
workflow_id: Optional specific workflow ID
session: Reuse this SQLAlchemy session for the lookup when provided,
so the caller's explicit session bears the connection cost
instead of Flask's request-scoped ``db.session``.
Returns:
Workflow instance
@@ -311,12 +329,12 @@ class AsyncWorkflowService:
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
workflow = workflow_service.get_published_workflow(app_model, session=session)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

View File

@@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
reservation_id: str
available: int
reserved: int
class QuotaCommitResult(TypedDict):
available: int
reserved: int
refunded: int
class QuotaReleaseResult(TypedDict):
available: int
reserved: int
released: int
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _TenantFeatureQuota(TypedDict):
usage: int
limit: int
reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
"""Response of /quota/info.
NOTE (hj24):
- Same convention as BillingInfo: billing may return int fields as str,
always keep non-strict mode to auto-coerce.
"""
trigger_event: _TenantFeatureQuota
api_rate_limit: _TenantFeatureQuota
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict):
size: int
limit: int
@@ -149,11 +193,63 @@ class BillingService:
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
params = {"tenant_id": tenant_id}
return _tenant_feature_quota_info_adapter.validate_python(
cls._send_request("GET", "/quota/info", params=params)
)
@classmethod
def quota_reserve(
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
"""Reserve quota before task execution."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"request_id": request_id,
"amount": amount,
}
if meta:
payload["meta"] = meta
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
"""Commit a reservation with actual consumption."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
"actual_amount": actual_amount,
}
if meta:
payload["meta"] = meta
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
"""Release a reservation (cancel, return frozen quota)."""
return _quota_release_adapter.validate_python(
cls._send_request(
"POST",
"/quota/release",
json={
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
},
)
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}

View File

@@ -3748,6 +3748,7 @@ class SegmentService:
ChildChunk.segment_id == segment.id,
)
)
assert current_user.current_tenant_id
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
@@ -3758,7 +3759,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=content,
word_count=len(content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)
db.session.add(child_chunk)
@@ -3818,6 +3819,7 @@ class SegmentService:
if new_child_chunks_args:
child_chunk_count = len(child_chunks)
for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
assert current_user.current_tenant_id
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(args.content)
child_chunk = ChildChunk(
@@ -3830,7 +3832,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=args.content,
word_count=len(args.content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)

View File

@@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
@@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
class EnterpriseService:
@classmethod
@ttl_cache(ttl=5)
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")

View File

@@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = []
enable_creators_platform: bool = False
enable_trial_app: bool = False
enable_explore_banner: bool = False
@@ -241,6 +242,9 @@ class FeatureService:
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
system_features.enable_creators_platform = True
return system_features
@classmethod
@@ -286,7 +290,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]

View File

@@ -0,0 +1,233 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota reservation (Reserve phase).
Lifecycle:
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
try:
do_work()
charge.commit() # Confirm consumption
except:
charge.refund() # Release frozen quota
If neither commit() nor refund() is called, the billing system's
cleanup CronJob will auto-release the reservation within ~75 seconds.
"""
success: bool
charge_id: str | None # reservation_id
_quota_type: QuotaType
_tenant_id: str | None = None
_feature_key: str | None = None
_amount: int = 0
_committed: bool = field(default=False, repr=False)
def commit(self, actual_amount: int | None = None) -> None:
"""
Confirm the consumption with actual amount.
Args:
actual_amount: Actual amount consumed. Defaults to the reserved amount.
If less than reserved, the difference is refunded automatically.
"""
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
return
try:
from services.billing_service import BillingService
amount = actual_amount if actual_amount is not None else self._amount
BillingService.quota_commit(
tenant_id=self._tenant_id,
feature_key=self._feature_key,
reservation_id=self.charge_id,
actual_amount=amount,
)
self._committed = True
logger.debug(
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
self._quota_type,
self._tenant_id,
self.charge_id,
amount,
)
except Exception:
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
def refund(self) -> None:
"""
Release the reserved quota (cancel the charge).
Safe to call even if:
- charge failed or was disabled (charge_id is None)
- already committed (Release after Commit is a no-op)
- already refunded (idempotent)
This method guarantees no exceptions will be raised.
"""
if not self.charge_id or not self._tenant_id or not self._feature_key:
return
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
from enums.quota_type import QuotaType
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
class QuotaService:
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
@staticmethod
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve + immediate Commit (one-shot mode).
The returned QuotaCharge supports .refund() which calls Release.
For two-phase usage (e.g. streaming), use reserve() directly.
"""
charge = QuotaService.reserve(quota_type, tenant_id, amount)
if charge.success and charge.charge_id:
charge.commit()
return charge
@staticmethod
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve quota before task execution (Reserve phase only).
The caller MUST call charge.commit() after the task succeeds,
or charge.refund() if the task fails.
Raises:
QuotaExceededError: When quota is insufficient
"""
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to reserve must be greater than 0")
request_id = str(uuid.uuid4())
feature_key = quota_type.billing_key
try:
reserve_resp = BillingService.quota_reserve(
tenant_id=tenant_id,
feature_key=feature_key,
request_id=request_id,
amount=amount,
)
reservation_id = reserve_resp.get("reservation_id")
if not reservation_id:
logger.warning(
"Reserve returned no reservation_id for %s, feature %s, response: %s",
tenant_id,
quota_type.value,
reserve_resp,
)
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
logger.debug(
"Reserved %d %s quota for tenant %s, reservation_id: %s",
amount,
quota_type.value,
tenant_id,
reservation_id,
)
return QuotaCharge(
success=True,
charge_id=reservation_id,
_quota_type=quota_type,
_tenant_id=tenant_id,
_feature_key=feature_key,
_amount=amount,
)
except QuotaExceededError:
raise
except ValueError:
raise
except Exception:
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
return unlimited()
@staticmethod
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = QuotaService.get_remaining(quota_type, tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
return True
@staticmethod
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
"""Release a reservation. Guarantees no exceptions."""
try:
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not reservation_id:
return
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
BillingService.quota_release(
tenant_id=tenant_id,
feature_key=feature_key,
reservation_id=reservation_id,
)
except Exception:
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
@staticmethod
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
from services.billing_service import BillingService
try:
usage_info = BillingService.get_quota_info(tenant_id)
if isinstance(usage_info, dict):
feature_info = usage_info.get(quota_type.billing_key, {})
if isinstance(feature_info, dict):
limit = feature_info.get("limit", 0)
usage = feature_info.get("usage", 0)
if limit == -1:
return -1
return max(0, limit - usage)
return 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
return -1

View File

@@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
@@ -521,7 +521,7 @@ class BuiltinToolManageService:
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
@@ -635,7 +635,7 @@ class TriggerProviderService:
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@@ -38,6 +38,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@@ -798,45 +799,47 @@ class WebhookService:
Exception: If workflow execution fails
"""
try:
with Session(db.engine) as session:
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Create trigger data
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id,
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
raise
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
# consume quota before triggering workflow execution
try:
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
try:
# NOTE: don not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`
# trigger_workflow_async need to handle multipe session commits internally
with Session(db.engine, expire_on_commit=False) as session:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
raise
# Trigger workflow execution asynchronously
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@@ -16,6 +16,7 @@ from graphon.model_runtime.entities.model_entities import ModelType
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.enums import SegmentType
logger = logging.getLogger(__name__)
@@ -178,7 +179,7 @@ class VectorService:
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
@@ -222,6 +223,7 @@ class VectorService:
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
assert update_child_chunk.index_node_id
child_document = Document(
page_content=update_child_chunk.content,
metadata={
@@ -234,6 +236,7 @@ class VectorService:
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
assert delete_child_chunk.index_node_id
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
@@ -246,6 +249,7 @@ class VectorService:
@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
assert child_chunk.index_node_id
vector.delete_by_ids([child_chunk.index_node_id])
@classmethod

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
HumanInputRequiredResponse,
MessageReplaceStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@@ -22,10 +23,14 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from graphon.runtime import GraphRuntimeState
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models.human_input import HumanInputForm
from models.model import AppMode, Message
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
@@ -59,8 +64,10 @@ def build_workflow_event_stream(
tenant_id: str,
app_id: str,
session_maker: sessionmaker[Session],
human_input_surface: HumanInputSurface | None = None,
idle_timeout: float = 300,
ping_interval: float = 10.0,
close_on_pause: bool = True,
) -> Generator[Mapping[str, Any] | str, None, None]:
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@@ -115,13 +122,15 @@ def build_workflow_event_stream(
message_context=message_context,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
human_input_surface=human_input_surface,
)
for event in snapshot_events:
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
if _is_terminal_event(event, close_on_pause=close_on_pause):
return
while True:
@@ -146,7 +155,7 @@ def build_workflow_event_stream(
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
if _is_terminal_event(event, close_on_pause=close_on_pause):
return
finally:
buffer_state.stop_event.set()
@@ -207,6 +216,8 @@ def _build_snapshot_events(
message_context: MessageContext | None,
pause_entity: WorkflowPauseEntity | None,
resumption_context: WorkflowResumptionContext | None,
session_maker: sessionmaker[Session] | None = None,
human_input_surface: HumanInputSurface | None = None,
) -> list[Mapping[str, Any]]:
events: list[Mapping[str, Any]] = []
@@ -241,12 +252,24 @@ def _build_snapshot_events(
events.append(node_finished)
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
for human_input_event in _build_human_input_required_events(
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
session_maker=session_maker,
human_input_surface=human_input_surface,
):
_apply_message_context(human_input_event, message_context)
events.append(human_input_event)
pause_event = _build_pause_event(
workflow_run=workflow_run,
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
human_input_surface=human_input_surface,
)
if pause_event is not None:
_apply_message_context(pause_event, message_context)
@@ -314,6 +337,97 @@ def _build_node_started_event(
return response.to_ignore_detail_dict()
def _build_human_input_required_events(
*,
workflow_run_id: str,
task_id: str,
pause_entity: WorkflowPauseEntity,
session_maker: sessionmaker[Session] | None,
human_input_surface: HumanInputSurface | None,
) -> list[dict[str, Any]]:
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
human_input_form_ids = [
form_id
for reason in reasons
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
for form_id in [reason.get("form_id")]
if isinstance(form_id, str)
]
expiration_times_by_form_id: dict[str, int] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
form_tokens_by_form_id: dict[str, str] = {}
if human_input_form_ids and session_maker is not None:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
HumanInputForm.id.in_(human_input_form_ids)
)
with session_maker() as session:
for form_id, expiration_time, form_definition in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
try:
definition_payload = json.loads(form_definition) if form_definition else {}
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_tokens_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
)
events: list[dict[str, Any]] = []
for reason in reasons:
if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED:
continue
form_id_raw = reason.get("form_id")
node_id_raw = reason.get("node_id")
node_title_raw = reason.get("node_title")
form_content_raw = reason.get("form_content")
if not isinstance(form_id_raw, str):
continue
if not isinstance(node_id_raw, str):
continue
if not isinstance(node_title_raw, str):
continue
if not isinstance(form_content_raw, str):
continue
form_id = form_id_raw
node_id = node_id_raw
node_title = node_title_raw
form_content = form_content_raw
inputs = reason.get("inputs")
actions = reason.get("actions")
resolved_default_values = reason.get("resolved_default_values")
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is None:
continue
response = HumanInputRequiredResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=HumanInputRequiredResponse.Data(
form_id=form_id,
node_id=node_id,
node_title=node_title,
form_content=form_content,
inputs=inputs if isinstance(inputs, list) else [],
actions=actions if isinstance(actions, list) else [],
display_in_ui=display_in_ui_by_form_id.get(form_id, False),
form_token=form_tokens_by_form_id.get(form_id),
resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}),
expiration_time=expiration_time,
),
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
events.append(payload)
return events
def _build_node_finished_event(
*,
workflow_run_id: str,
@@ -356,6 +470,8 @@ def _build_pause_event(
task_id: str,
pause_entity: WorkflowPauseEntity,
resumption_context: WorkflowResumptionContext | None,
session_maker: sessionmaker[Session] | None,
human_input_surface: HumanInputSurface | None = None,
) -> dict[str, Any] | None:
paused_nodes: list[str] = []
outputs: dict[str, Any] = {}
@@ -365,6 +481,36 @@ def _build_pause_event(
outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
human_input_form_ids = [
form_id
for reason in reasons
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
for form_id in [reason.get("form_id")]
if isinstance(form_id, str)
]
form_tokens_by_form_id: dict[str, str] = {}
expiration_times_by_form_id: dict[str, int] = {}
if human_input_form_ids and session_maker is not None:
with session_maker() as session:
form_tokens_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
)
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
for row in session.execute(stmt):
form_id, expiration_time, *_rest = row
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
reasons = enrich_human_input_pause_reasons(
reasons,
form_tokens_by_form_id=form_tokens_by_form_id,
expiration_times_by_form_id=expiration_times_by_form_id,
)
response = WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
@@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
return event
def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
def _is_terminal_event(
event: Mapping[str, Any] | str,
close_on_pause: bool = True,
*,
include_paused: bool | None = None,
) -> bool:
if include_paused is not None:
close_on_pause = include_paused
if not isinstance(event, Mapping):
return False
event_type = event.get("event")
if event_type == StreamEvent.WORKFLOW_FINISHED.value:
return True
if include_paused:
if close_on_pause:
return event_type == StreamEvent.WORKFLOW_PAUSED.value
return False

View File

@@ -156,11 +156,18 @@ class WorkflowService:
# return draft workflow
return workflow
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
def get_published_workflow_by_id(
self, app_model: App, workflow_id: str, session: Session | None = None
) -> Workflow | None:
"""
fetch published workflow by workflow_id
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
workflow = db.session.scalar(
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
@@ -178,16 +185,20 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App) -> Workflow | None:
def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None:
"""
Get published workflow
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
if not app_model.workflow_id:
return None
# fetch published workflow by workflow_id
workflow = db.session.scalar(
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,

View File

@@ -399,6 +399,8 @@ def _resume_advanced_chat(
workflow_run_id: str,
workflow_run: WorkflowRun,
) -> None:
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@@ -426,7 +428,7 @@ def _resume_advanced_chat(
user=user,
conversation=conversation,
message=message,
application_generate_entity=generate_entity,
application_generate_entity=resumed_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_runtime_state=graph_runtime_state,
@@ -436,9 +438,8 @@ def _resume_advanced_chat(
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
raise
if generate_entity.stream:
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
def _resume_workflow(
@@ -455,6 +456,8 @@ def _resume_workflow(
workflow_run_repo,
pause_entity,
) -> None:
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@@ -480,7 +483,7 @@ def _resume_workflow(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=generate_entity,
application_generate_entity=resumed_generate_entity,
graph_runtime_state=graph_runtime_state,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
@@ -490,11 +493,18 @@ def _resume_workflow(
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
raise
if generate_entity.stream:
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
workflow_run_repo.delete_workflow_pause(pause_entity)
try:
workflow_run_repo.delete_workflow_pause(pause_entity)
except Exception as exc:
if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc):
raise
logger.info(
"Skipped deleting workflow pause %s after resume because it was already replaced or removed",
pause_entity.id,
)
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")

View File

@@ -27,7 +27,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from graphon.enums import WorkflowExecutionStatus
from models.enums import (
AppTriggerType,
@@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@@ -258,59 +259,58 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
# Ensure expire_on_commit is set to False to remain workflows available
with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
for plugin_trigger in subscribers:
# Get workflow from mapping
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
# Find the trigger node in the workflow
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
# invoke trigger
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
for plugin_trigger in subscribers:
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
# consume quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
)
return 0
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id)
return dispatched_count
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
with session_factory.create_session() as session:
try:
invoke_response = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
@@ -387,6 +387,7 @@ def dispatch_triggered_workflow(
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
quota_charge.commit()
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",
@@ -401,7 +402,7 @@ def dispatch_triggered_workflow(
plugin_trigger.app_id,
)
return dispatched_count
return dispatched_count
def dispatch_triggered_workflows(

View File

@@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import (
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@@ -32,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
# Ensure expire_on_commit is set to False to remain schedule/tenant_owner available
with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
@@ -41,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None:
if not tenant_owner:
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
try:
# Production dispatch: Trigger the workflow normally
try:
with session_factory.create_session() as session:
response = AsyncWorkflowService.trigger_workflow_async(
session=session,
user=tenant_owner,
@@ -61,9 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e

View File

@@ -171,35 +171,13 @@ class TestChatMessageApiPermissions:
parent_message_id=None,
)
class MockQuery:
def __init__(self, model):
self.model = model
def where(self, *args, **kwargs):
return self
def first(self):
if getattr(self.model, "__name__", "") == "Conversation":
return mock_conversation
return None
def order_by(self, *args, **kwargs):
return self
def limit(self, *_):
return self
def all(self):
if getattr(self.model, "__name__", "") == "Message":
return [mock_message]
return []
mock_session = mock.Mock()
mock_session.query.side_effect = MockQuery
mock_session.scalar.return_value = False
mock_session.scalar.return_value = mock_conversation
mock_session.scalars.return_value.all.return_value = [mock_message]
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
monkeypatch.setattr(message_api, "current_user", mock_account)
monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock())
class DummyPagination:
def __init__(self, data, limit, has_more):

View File

@@ -24,7 +24,6 @@ def _patch_wraps():
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
@@ -11,6 +12,7 @@ import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.human_input_adapter import DeliveryMethodType
from extensions.ext_storage import storage
from graphon.entities import WorkflowExecution
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
@@ -20,9 +22,11 @@ from graphon.nodes.human_input.enums import HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.human_input import (
BackstageRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
@@ -628,12 +632,12 @@ class TestPrivateWorkflowPauseEntity:
class TestBuildHumanInputRequiredReason:
"""Integration tests for _build_human_input_required_reason using real DB models."""
def test_builds_reason_from_form_definition(
def test_prefers_standalone_web_app_token_when_available(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Build the graph pause reason from the stored form definition."""
"""Use the public standalone web-app token for service API payloads."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
@@ -660,6 +664,40 @@ class TestBuildHumanInputRequiredReason:
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
backstage_access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=backstage_access_token,
)
console_access_token = secrets.token_urlsafe(8)
console_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.CONSOLE,
recipient_payload="{}",
access_token=console_access_token,
)
web_app_access_token = secrets.token_urlsafe(8)
web_app_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload="{}",
access_token=web_app_access_token,
)
db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient])
db_session_with_containers.flush()
# Create a pause so the reason has a valid pause_id
workflow_run = _create_workflow_run(
db_session_with_containers,
@@ -688,8 +726,15 @@ class TestBuildHumanInputRequiredReason:
# Refresh to ensure we have DB-round-tripped objects
db_session_with_containers.refresh(form_model)
db_session_with_containers.refresh(reason_model)
db_session_with_containers.refresh(backstage_recipient)
db_session_with_containers.refresh(console_recipient)
db_session_with_containers.refresh(web_app_recipient)
reason = _build_human_input_required_reason(reason_model, form_model)
reason = _build_human_input_required_reason(
reason_model,
form_model,
[backstage_recipient, console_recipient, web_app_recipient],
)
assert isinstance(reason, HumanInputRequired)
assert reason.node_title == "Ask Name"
@@ -697,3 +742,92 @@ class TestBuildHumanInputRequiredReason:
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"
assert reason.resolved_default_values == {"name": "Alice"}
assert not hasattr(reason, "form_token")
def test_falls_back_to_console_token_when_web_app_token_missing(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Use the console token only when no standalone web-app token exists."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
form_content="content",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values={"name": "Alice"},
node_title="Ask Name",
display_in_ui=True,
)
form_model = HumanInputForm(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
workflow_run_id=str(uuid4()),
node_id="node-1",
form_definition=form_definition.model_dump_json(),
rendered_content="rendered",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
backstage_access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=backstage_access_token,
)
console_access_token = secrets.token_urlsafe(8)
console_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.CONSOLE,
recipient_payload="{}",
access_token=console_access_token,
)
db_session_with_containers.add_all([backstage_recipient, console_recipient])
db_session_with_containers.flush()
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause = WorkflowPause(
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
db_session_with_containers.add(pause)
db_session_with_containers.flush()
test_scope.state_keys.add(pause.state_object_key)
reason_model = WorkflowPauseReason(
pause_id=pause.id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=form_model.id,
node_id="node-1",
message="",
)
db_session_with_containers.add(reason_model)
db_session_with_containers.commit()
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient])
assert isinstance(reason, HumanInputRequired)
assert not hasattr(reason, "form_token")

View File

@@ -36,12 +36,19 @@ class TestAppGenerateService:
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
"result": "success",
"history_id": "test_history_id",
mock_billing_service.quota_reserve.return_value = {
"reservation_id": "test-reservation-id",
"available": 100,
"reserved": 1,
}
mock_billing_service.quota_commit.return_value = {
"available": 99,
"reserved": 0,
"refunded": 0,
}
# Setup default mock returns for workflow service
@@ -101,6 +108,8 @@ class TestAppGenerateService:
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_quota_dify_config.BILLING_ENABLED = False
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@@ -118,6 +127,7 @@ class TestAppGenerateService:
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"quota_dify_config": mock_quota_dify_config,
"global_dify_config": mock_global_dify_config,
}
@@ -465,6 +475,7 @@ class TestAppGenerateService:
# Set BILLING_ENABLED to True for this test
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
# Setup test arguments
@@ -478,8 +489,10 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
# Verify billing two-phase quota (reserve + commit)
billing = mock_external_service_dependencies["billing_service"]
billing.quota_reserve.assert_called_once()
billing.quota_commit.assert_called_once()
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
def _execute_result(rows):
result = mock.Mock()
result.all.return_value = rows
return result
class TestFeedbackService:
"""Test FeedbackService methods."""
@@ -81,25 +87,17 @@ class TestFeedbackService:
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in CSV format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test CSV export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -120,25 +118,17 @@ class TestFeedbackService:
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in JSON format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test JSON export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -157,25 +147,17 @@ class TestFeedbackService:
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
"""Test exporting feedback with various filters."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test with filters
result = FeedbackService.export_feedbacks(
@@ -193,17 +175,7 @@ class TestFeedbackService:
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
"""Test exporting feedback when no data exists."""
# Setup mock query result with no data
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result([])
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -251,24 +223,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -309,24 +274,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None, # No account for user feedback
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -339,32 +297,24 @@ class TestFeedbackService:
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
"""Test that rating emojis are properly formatted in export."""
# Setup mock query result with both like and dislike feedback
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

View File

@@ -10,6 +10,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from enums.quota_type import QuotaType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
@@ -290,17 +291,26 @@ class TestWebhookServiceTriggerExecutionWithContainers:
end_user = SimpleNamespace(id=str(uuid4()))
webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
quota_charge = MagicMock()
with (
patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
return_value=end_user,
),
patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume,
patch(
"services.trigger.webhook_service.QuotaService.reserve",
return_value=quota_charge,
) as mock_reserve,
patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
):
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
mock_consume.assert_called_once_with(webhook_trigger.tenant_id)
mock_reserve.assert_called_once()
reserve_args = mock_reserve.call_args.args
assert reserve_args[0] == QuotaType.TRIGGER
assert reserve_args[1] == webhook_trigger.tenant_id
quota_charge.commit.assert_called_once()
mock_trigger.assert_called_once()
trigger_args = mock_trigger.call_args.args
assert trigger_args[1] is end_user
@@ -327,7 +337,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
return_value=SimpleNamespace(id=str(uuid4())),
),
patch(
"services.trigger.webhook_service.QuotaType.TRIGGER.consume",
"services.trigger.webhook_service.QuotaService.reserve",
side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
),
patch(

View File

@@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log(
)
# Mock quota to avoid rate limiting
from enums import quota_type
from services import quota_service
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)

View File

@@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
configure_session_factory(_unit_test_engine, expire_on_commit=False)
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
"""
Helper to set up the mock DB execute chain for tenant/account authentication.
Helper to stub the tenant-owner execute result for service API app authentication.
This configures the mock to return (tenant, account) for the
db.session.execute(select(...).join().join().where()).one_or_none()
query used by validate_app_token decorator.
The validate_app_token decorator currently resolves the active tenant owner
via db.session.execute(select(Tenant, Account)...).one_or_none().
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_account: Mock account object to return
mock_owner: Mock owner object to return from the execute result
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
"""
Helper to set up the mock DB execute chain for dataset tenant authentication.
Helper to stub the tenant-owner execute result for dataset token authentication.
This configures the mock to return (tenant, tenant_account) for the
db.session.execute(select(...).where().where().where().where()).one_or_none()
query used by validate_dataset_token decorator.
The validate_dataset_token decorator currently resolves the owner mapping via
db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
then loads the Account separately via db.session.get(...).
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_ta: Mock tenant account object to return
mock_tenant_account_join: Mock tenant-account join object to return
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)

View File

@@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
@@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")

View File

@@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
# Act
@@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Act
with self.app.test_request_context(
@@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
# Act
@@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")

View File

@@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
- IP rate limiting is checked
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "email_token_123"
@@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is allowed by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = True
@@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is blocked by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = False
@@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
- Prevents spam and abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.side_effect = AccountRegisterError("Account frozen")
@@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
- Defaults to en-US when not specified
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token"
@@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
- User is logged in with token pair
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
- User is logged in after account creation
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
mock_get_user.return_value = None
mock_create_account.return_value = mock_account
@@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
- InvalidEmailError is raised when email doesn't match token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
# Act & Assert
@@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
- EmailCodeError is raised for wrong verification code
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
# Act & Assert
@@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
- User is added as owner of new workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
- WorkspacesLimitExceeded is raised when limit reached
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
- NotAllowedCreateWorkspace is raised when creation disabled
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []

View File

@@ -110,7 +110,6 @@ class TestLoginApi:
- Rate limit is reset after successful login
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@@ -162,7 +161,6 @@ class TestLoginApi:
- Authentication proceeds with invitation token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
mock_authenticate.return_value = mock_account
@@ -199,7 +197,6 @@ class TestLoginApi:
- EmailPasswordLoginLimitError is raised when limit exceeded
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
mock_get_invitation.return_value = None
@@ -228,7 +225,6 @@ class TestLoginApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_frozen.return_value = True
# Act & Assert
@@ -268,7 +264,6 @@ class TestLoginApi:
- Generic error message prevents user enumeration
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
@@ -305,7 +300,6 @@ class TestLoginApi:
- Login is prevented even with valid credentials
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountLoginError("Account is banned")
@@ -351,7 +345,6 @@ class TestLoginApi:
- User cannot login without an assigned workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@@ -383,7 +376,6 @@ class TestLoginApi:
- Security check prevents invitation token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
@@ -425,7 +417,6 @@ class TestLoginApi:
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
@@ -459,7 +450,6 @@ class TestLoginApi:
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")
@@ -513,7 +503,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
# Act
@@ -539,7 +528,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
# Create a mock anonymous user that will pass isinstance check
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})

View File

@@ -46,7 +46,6 @@ class TestPartnerTenants:
patch("libs.login.dify_config.LOGIN_DISABLED", False),
patch("libs.login.check_csrf_token") as mock_csrf,
):
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}

View File

@@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
TagBindingDeleteApi,
DeprecatedTagBindingCreateApi,
DeprecatedTagBindingRemoveApi,
TagBindingCollectionApi,
TagBindingItemApi,
TagListApi,
TagUpdateDeleteApi,
)
@@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi:
assert status == 204
class TestTagBindingCreateApi:
class TestTagBindingCollectionApi:
def test_create_success(self, app, admin_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
payload = {
@@ -232,7 +234,7 @@ class TestTagBindingCreateApi:
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@@ -247,9 +249,78 @@ class TestTagBindingCreateApi:
method(api)
class TestTagBindingDeleteApi:
class TestDeprecatedTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = DeprecatedTagBindingCreateApi()
method = unwrap(api.post)
payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
class TestTagBindingItemApi:
def test_delete_success(self, app, admin_user, payload_patch):
api = TagBindingItemApi()
method = unwrap(api.delete)
payload = {
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api, "tag-1")
delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
assert delete_payload.tag_id == "tag-1"
assert delete_payload.target_id == "target-1"
assert delete_payload.type == TagType.KNOWLEDGE
assert status == 200
assert result["result"] == "success"
def test_delete_forbidden(self, app, readonly_user):
api = TagBindingItemApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
class TestDeprecatedTagBindingRemoveApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
payload = {
@@ -274,7 +345,7 @@ class TestTagBindingDeleteApi:
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@@ -297,3 +368,35 @@ class TestTagResponseModel:
assert payload["type"] == "knowledge"
assert payload["binding_count"] == "1"
class TestTagBindingRouteMetadata:
def test_legacy_write_routes_are_marked_deprecated(self):
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True
def test_write_routes_have_stable_operation_ids(self):
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"
def test_canonical_and_legacy_write_routes_are_registered(self):
route_map = {
resource.__name__: urls
for resource, urls, _route_doc, _kwargs in console_ns.resources
if resource.__name__
in {
"TagBindingCollectionApi",
"TagBindingItemApi",
"DeprecatedTagBindingCreateApi",
"DeprecatedTagBindingRemoveApi",
}
}
assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)

View File

@@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
handler(api, form_token="token")
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)

View File

@@ -24,10 +24,6 @@ def app():
return app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
@@ -64,7 +60,6 @@ class TestChangeEmailSend:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@@ -117,7 +112,6 @@ class TestChangeEmailSend:
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@@ -163,7 +157,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
@@ -223,7 +216,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -280,7 +272,6 @@ class TestChangeEmailValidity:
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -330,7 +321,6 @@ class TestChangeEmailValidity:
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@@ -378,7 +368,6 @@ class TestChangeEmailReset:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -434,7 +423,6 @@ class TestChangeEmailReset:
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -488,7 +476,6 @@ class TestChangeEmailReset:
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@@ -578,7 +564,6 @@ class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask, g
@@ -16,10 +16,6 @@ def app():
return flask_app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
@@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@@ -310,7 +310,6 @@ class TestSystemSetup:
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
@setup_required
def admin_view():

View File

@@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None
@contextmanager
def _mock_db():
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
with patch("extensions.ext_database.db.session", mock_session):
yield

View File

@@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
from controllers.service_api.app.error import AppUnavailableError
from models.account import TenantStatus
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
class TestAppParameterApi:
@@ -74,7 +74,7 @@ class TestAppParameterApi:
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -120,7 +120,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -161,7 +161,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -200,7 +200,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -263,7 +263,7 @@ class TestAppMetaApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -331,7 +331,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -388,7 +388,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -434,7 +434,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -486,7 +486,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):

View File

@@ -0,0 +1,707 @@
"""Dedicated tests for HITL behavior exposed through the Service API."""
from __future__ import annotations
import json
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock
import pytest
import services.app_generate_service as ags_module
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.common import workflow_response_converter
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
HumanInputRequiredResponse,
WorkflowAppPausedBlockingResponse,
WorkflowPauseStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.human_input_policy import HumanInputSurface
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from graphon.nodes.human_input.entities import FormInput, UserAction
from graphon.nodes.human_input.enums import FormInputType
from graphon.runtime import GraphRuntimeState, VariablePool
from models.account import Account
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.app_generate_service import AppGenerateService
from services.workflow_event_snapshot_service import _build_snapshot_events
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class _DummyRateLimit:
@staticmethod
def gen_request_key() -> str:
return "dummy-request-id"
def __init__(self, client_id: str, max_active_requests: int) -> None:
self.client_id = client_id
self.max_active_requests = max_active_requests
def enter(self, request_id: str | None = None) -> str:
return request_id or "dummy-request-id"
def exit(self, request_id: str) -> None:
return None
def generate(self, generator, request_id: str):
return generator
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
monkeypatch.setattr(
workflow_events_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
return workflow_events_module
def _build_service_api_pause_converter() -> WorkflowResponseConverter:
application_generate_entity = SimpleNamespace(
inputs={},
files=[],
invoke_from=InvokeFrom.SERVICE_API,
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
)
system_variables = build_system_variables(
user_id="user",
app_id="app-id",
workflow_id="workflow-id",
workflow_execution_id="run-id",
)
user = MagicMock(spec=Account)
user.id = "account-id"
user.name = "Tester"
user.email = "tester@example.com"
return WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=system_variables,
)
def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
data = AdvancedChatPausedBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
workflow_run_id="run-1",
answer="partial",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
paused_nodes=["node-1"],
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"expiration_time": 100,
}
],
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
)
return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
return WorkflowAppPausedBlockingResponse(
task_id="t1",
workflow_run_id="r1",
data=WorkflowAppPausedBlockingResponse.Data(
id="r1",
workflow_id="wf-1",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=0.5,
total_tokens=0,
total_steps=2,
created_at=1,
finished_at=None,
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
),
)
@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
pause_id: str
workflow_run_id: str
paused_at_value: datetime
pause_reasons: Sequence[HumanInputRequired]
@property
def id(self) -> str:
return self.pause_id
@property
def workflow_execution_id(self) -> str:
return self.workflow_run_id
def get_state(self) -> bytes:
raise AssertionError("state is not required for snapshot tests")
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return self.paused_at_value
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
return self.pause_reasons
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"input": "value"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
created_at = datetime(2024, 1, 1, tzinfo=UTC)
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
return WorkflowNodeExecutionSnapshot(
execution_id="exec-1",
node_id="node-1",
node_type="human-input",
title="Human Input",
index=1,
status=status.value,
elapsed_time=0.5,
created_at=created_at,
finished_at=finished_at,
iteration_id=None,
loop_id=None,
)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.register_paused_node("node-1")
runtime_state.outputs = {"result": "value"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
class TestHitlServiceApi:
# Service API event-stream continuation
def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
msg_generator.retrieve_events.return_value = ["raw-event"]
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
msg_generator.retrieve_events.assert_called_once_with(
AppMode.WORKFLOW,
"run-1",
terminal_events=[],
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
self, app, monkeypatch: pytest.MonkeyPatch
) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
snapshot_builder = Mock(return_value=["snapshot-events"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
method="GET",
):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called()
snapshot_builder.assert_called_once_with(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=ANY,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=False,
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
workflow = MagicMock()
workflow.created_by = "owner-id"
monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
generator_instance = MagicMock()
generator_instance.generate.return_value = {"result": "advanced-blocking"}
generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
app_model = MagicMock()
app_model.mode = AppMode.ADVANCED_CHAT
app_model.id = "app-id"
app_model.tenant_id = "tenant-id"
app_model.max_active_requests = 0
app_model.is_agent = False
user = MagicMock()
user.id = "user-id"
result = AppGenerateService.generate(
app_model=app_model,
user=user,
args={"workflow_id": None, "query": "hi", "inputs": {}},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
assert result == {"result": "advanced-blocking"}
call_kwargs = generator_instance.generate.call_args.kwargs
assert call_kwargs["streaming"] is False
assert call_kwargs["pause_state_config"] is not None
assert call_kwargs["pause_state_config"].session_factory == "session-maker"
assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
# Blocking payload contract
def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
_build_advanced_chat_paused_blocking_response()
)
assert response["event"] == "workflow_paused"
assert response["workflow_run_id"] == "run-1"
assert response["answer"] == "partial"
assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
assert response["data"]["reasons"][0]["expiration_time"] == 100
assert "human_input_forms" not in response["data"]
def test_workflow_blocking_pause_payload_contract(self) -> None:
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
_build_workflow_paused_blocking_response()
)
assert response["workflow_run_id"] == "r1"
assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
assert response["data"]["paused_nodes"] == ["node-1"]
assert response["data"]["reasons"] == [
{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}
]
assert "human_input_forms" not in response["data"]
def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
from core.app.app_config.entities import AppAdditionalFeatures
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from models.enums import MessageStatus
from models.model import EndUser
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.ADVANCED_CHAT,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
query="hello",
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
extras={},
trace_manager=None,
workflow_run_id="run-id",
)
pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
message=SimpleNamespace(
id="message-id",
query="hello",
created_at=datetime.utcnow(),
status=MessageStatus.NORMAL,
answer="",
),
user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
stream=False,
dialogue_count=1,
draft_var_saver_factory=lambda **kwargs: None,
)
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Approval",
form_content="Need approval",
inputs=[],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
form_token="token-1",
resolved_default_values={},
expiration_time=123,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run-id",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run-id",
paused_nodes=["node-1"],
outputs={},
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"node_id": "node-1",
"expiration_time": 123,
},
],
status="paused",
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, AdvancedChatPausedBlockingResponse)
assert response.data.answer == "partial answer"
assert response.data.workflow_run_id == "run-id"
assert response.data.reasons[0]["form_id"] == "form-1"
assert response.data.reasons[0]["expiration_time"] == 123
def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
trace_manager=None,
workflow_execution_id="run-id",
extras={},
call_depth=0,
)
pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
user=SimpleNamespace(id="user", session_id="session"),
stream=False,
draft_var_saver_factory=lambda **kwargs: None,
)
monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000)
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Human Input",
form_content="content",
expiration_time=1,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, WorkflowAppPausedBlockingResponse)
assert response.data.status == WorkflowExecutionStatus.PAUSED
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}]
def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
converter = _build_service_api_pause_converter()
converter.workflow_start_to_stream_response(
task_id="task",
workflow_run_id="run-id",
workflow_id="workflow-id",
reason=WorkflowStartReason.INITIAL,
)
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
class _FakeSession:
def execute(self, _stmt):
return [("form-1", expiration_time, '{"display_in_ui": true}')]
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
workflow_response_converter,
"load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "token"},
)
reason = HumanInputRequired(
form_id="form-1",
form_content="Rendered",
inputs=[
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
node_id="node-id",
node_title="Human Step",
form_token="token",
)
queue_event = QueueWorkflowPausedEvent(
reasons=[reason],
outputs={"answer": "value"},
paused_nodes=["node-id"],
)
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
responses = converter.workflow_pause_to_stream_response(
event=queue_event,
task_id="task",
graph_runtime_state=runtime_state,
)
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
pause_resp = responses[-1]
assert pause_resp.workflow_run_id == "run-id"
assert pause_resp.data.paused_nodes == ["node-id"]
assert pause_resp.data.outputs == {}
assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required"
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
assert pause_resp.data.reasons[0]["form_token"] == "token"
assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())
assert isinstance(responses[0], HumanInputRequiredResponse)
hi_resp = responses[0]
assert hi_resp.data.form_id == "form-1"
assert hi_resp.data.node_id == "node-id"
assert hi_resp.data.node_title == "Human Step"
assert hi_resp.data.inputs[0].output_variable_name == "field"
assert hi_resp.data.actions[0].id == "approve"
assert hi_resp.data.display_in_ui is True
assert hi_resp.data.form_token == "token"
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
# Snapshot payload contract
def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
resumption_context = _build_resumption_context("task-ctx")
monkeypatch.setattr(
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
)
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def session_maker() -> _SessionContext:
return _SessionContext(
SimpleNamespace(
execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
)
)
pause_entity = _FakePauseEntity(
pause_id="pause-1",
workflow_run_id="run-1",
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
pause_reasons=[
HumanInputRequired(
form_id="form-1",
form_content="content",
node_id="node-1",
node_title="Human Input",
form_token="wtok",
)
],
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-ctx",
message_context=None,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
)
assert [event["event"] for event in events] == [
"workflow_started",
"node_started",
"node_finished",
"human_input_required",
"workflow_paused",
]
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
assert events[3]["data"]["form_token"] == "wtok"
assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
pause_data = events[-1]["data"]
assert pause_data["paused_nodes"] == ["node-1"]
assert pause_data["outputs"] == {"result": "value"}
assert pause_data["reasons"][0]["TYPE"] == "human_input_required"
assert pause_data["reasons"][0]["form_token"] == "wtok"
assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
assert pause_data["total_tokens"] == workflow_run.total_tokens
assert pause_data["total_steps"] == workflow_run.total_steps

Some files were not shown because too many files have changed in this diff Show More