Merge branch 'main' into jzh

This commit is contained in:
JzoNg
2026-04-09 13:15:20 +08:00
40 changed files with 745 additions and 390 deletions

View File

@@ -71,6 +71,13 @@ REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
REDIS_RETRY_RETRIES=3
REDIS_RETRY_BACKOFF_BASE=1.0
REDIS_RETRY_BACKOFF_CAP=10.0
REDIS_SOCKET_TIMEOUT=5.0
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
REDIS_HEALTH_CHECK_INTERVAL=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis

View File

@@ -117,6 +117,37 @@ class RedisConfig(BaseSettings):
default=None,
)
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
description="Maximum number of retries per Redis command on "
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
default=3,
)
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
description="Base delay in seconds for exponential backoff between retries",
default=1.0,
)
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
description="Maximum backoff delay in seconds between retries",
default=10.0,
)
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis read/write operations",
default=5.0,
)
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis connection establishment",
default=5.0,
)
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
description="Interval in seconds between Redis connection health checks (0 to disable)",
default=30,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):

View File

@@ -384,24 +384,27 @@ class VariableApi(Resource):
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -223,24 +223,27 @@ class RagPipelineVariableApi(Resource):
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -168,12 +168,13 @@ class ConsoleWorkflowEventsApi(Resource):
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
match app.mode:
case AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
case AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
case _:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"

View File

@@ -72,12 +72,13 @@ class WorkflowEventsApi(WebApiResource):
app_mode = AppMode.value_of(app_model.mode)
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
match app_mode:
case AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
case AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
case _:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"

View File

@@ -57,36 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
match system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
session.execute(stmt)
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
case ProviderQuotaType.FREE:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)

View File

@@ -40,41 +40,44 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
match message_file.transfer_method:
case FileTransferMethod.REMOTE_URL:
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
case FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
case FileTransferMethod.TOOL_FILE if message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE:
pass
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""

View File

@@ -187,15 +187,16 @@ def build_parameter_schema(
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
match app.mode:
case AppMode.WORKFLOW:
return {"inputs": arguments}
case AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
case _:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
def extract_answer_from_response(app: App, response: Any) -> str:
@@ -229,17 +230,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT,
AppMode.COMPLETION,
AppMode.CHAT,
AppMode.AGENT_CHAT,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))
match app.mode:
case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
return response.get("answer", "")
case AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
case _:
raise ValueError("Invalid app mode: " + str(app.mode))
def convert_input_form_to_parameters(

View File

@@ -72,17 +72,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
conversation_id = conversation_id or ""
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
if not query:
raise ValueError("missing query")
match app.mode:
case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT:
if not query:
raise ValueError("missing query")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
elif app.mode == AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
elif app.mode == AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
raise ValueError("unexpected app type")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
case AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
case AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
case _:
raise ValueError("unexpected app type")
@classmethod
def invoke_chat_app(
@@ -98,60 +99,61 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke chat app
"""
if app.mode == AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
match app.mode:
case AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
case AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
case AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
case _:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
else:
raise ValueError("unexpected app type")
@classmethod
def invoke_workflow_app(
cls,

View File

@@ -961,36 +961,37 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
match provider_quota.quota_type:
case ProviderQuotaType.TRIAL if trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
case ProviderQuotaType.PAID if paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
case _:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@@ -7,14 +7,13 @@ from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

View File

@@ -1,11 +1,10 @@
from typing import NotRequired, TypedDict, cast
from typing import cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument
@@ -17,18 +16,6 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,

View File

@@ -155,24 +155,25 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = raw_data
continue
if param_type == SegmentType.FILE:
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
match param_type:
case SegmentType.FILE:
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
case _:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data

View File

@@ -68,46 +68,49 @@ class EnterpriseMetricHandler:
# Route to appropriate handler based on case
case = envelope.case
if case == TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
elif case == TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
elif case == TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
elif case == TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
elif case == TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
elif case == TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
elif case == TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
elif case == TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
elif case == TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
elif case == TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
elif case == TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
else:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
match case:
case TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
case TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
case TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
case TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
case TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
case TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
case TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
case TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
case TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
case TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
case TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
case TelemetryCase.WORKFLOW_RUN | TelemetryCase.NODE_EXECUTION | TelemetryCase.DRAFT_NODE_EXECUTION:
pass
case _:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
"""Check if this event has already been processed.

View File

@@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union
import redis
from redis import RedisError
from redis.backoff import ExponentialWithJitterBackoff # type: ignore
from redis.cache import CacheConfig
from redis.client import PubSub
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.retry import Retry
from redis.sentinel import Sentinel
from configs import dify_config
@@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None:
return CacheConfig()
def _get_retry_policy() -> Retry:
"""Build the shared retry policy for Redis connections."""
return Retry(
backoff=ExponentialWithJitterBackoff(
base=dify_config.REDIS_RETRY_BACKOFF_BASE,
cap=dify_config.REDIS_RETRY_BACKOFF_CAP,
),
retries=dify_config.REDIS_RETRY_RETRIES,
)
def _get_connection_health_params() -> dict[str, Any]:
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
return {
"retry": _get_retry_policy(),
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
}
def _get_cluster_connection_health_params() -> dict[str, Any]:
"""Get retry and timeout parameters for Redis Cluster clients.
RedisCluster does not support ``health_check_interval`` as a constructor
keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params = _get_connection_health_params()
return {k: v for k, v in params.items() if k != "health_check_interval"}
def _get_base_redis_params() -> dict[str, Any]:
"""Get base Redis connection parameters."""
"""Get base Redis connection parameters including retry and health policy."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None,
@@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]:
"decode_responses": False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_connection_health_params(),
}
@@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
"password": dify_config.REDIS_CLUSTERS_PASSWORD,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_cluster_connection_health_params(),
}
if dify_config.REDIS_MAX_CONNECTIONS:
cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
@@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()
redis_params.update(
params = {**redis_params}
params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
@@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
)
if dify_config.REDIS_MAX_CONNECTIONS:
redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
if ssl_kwargs:
redis_params.update(ssl_kwargs)
params.update(ssl_kwargs)
pool = redis.ConnectionPool(**redis_params)
pool = redis.ConnectionPool(**params)
client: redis.Redis = redis.Redis(connection_pool=pool)
return client
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
max_conns = dify_config.REDIS_MAX_CONNECTIONS
if use_clusters:
if max_conns:
return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
else:
return RedisCluster.from_url(pubsub_url)
if use_clusters:
health_params = _get_cluster_connection_health_params()
kwargs: dict[str, Any] = {**health_params}
if max_conns:
kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs)
health_params = _get_connection_health_params()
kwargs = {**health_params}
if max_conns:
return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
else:
return redis.Redis.from_url(pubsub_url)
kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs)
def init_app(app: DifyApp):

View File

@@ -40,7 +40,7 @@ dependencies = [
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.40.0",
"opentelemetry-distro==0.61b0",
"opentelemetry-exporter-otlp==1.40.0",

View File

@@ -1,53 +1,125 @@
from unittest.mock import patch
from redis import RedisError
from redis.retry import Retry
from extensions.ext_redis import redis_fallback
from extensions.ext_redis import (
_get_base_redis_params,
_get_cluster_connection_health_params,
_get_connection_health_params,
redis_fallback,
)
def test_redis_fallback_success():
@redis_fallback(default_return=None)
def test_func():
return "success"
class TestGetConnectionHealthParams:
@patch("extensions.ext_redis.dify_config")
def test_includes_all_health_params(self, mock_config):
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() == "success"
params = _get_connection_health_params()
assert "retry" in params
assert "socket_timeout" in params
assert "socket_connect_timeout" in params
assert "health_check_interval" in params
assert isinstance(params["retry"], Retry)
assert params["retry"]._retries == 3
assert params["socket_timeout"] == 5.0
assert params["socket_connect_timeout"] == 5.0
assert params["health_check_interval"] == 30
def test_redis_fallback_error():
@redis_fallback(default_return="fallback")
def test_func():
raise RedisError("Redis error")
class TestGetClusterConnectionHealthParams:
@patch("extensions.ext_redis.dify_config")
def test_excludes_health_check_interval(self, mock_config):
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() == "fallback"
params = _get_cluster_connection_health_params()
assert "retry" in params
assert "socket_timeout" in params
assert "socket_connect_timeout" in params
assert "health_check_interval" not in params
def test_redis_fallback_none_default():
@redis_fallback()
def test_func():
raise RedisError("Redis error")
class TestGetBaseRedisParams:
@patch("extensions.ext_redis.dify_config")
def test_includes_retry_and_health_params(self, mock_config):
mock_config.REDIS_USERNAME = None
mock_config.REDIS_PASSWORD = None
mock_config.REDIS_DB = 0
mock_config.REDIS_SERIALIZATION_PROTOCOL = 3
mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() is None
params = _get_base_redis_params()
assert "retry" in params
assert isinstance(params["retry"], Retry)
assert params["socket_timeout"] == 5.0
assert params["socket_connect_timeout"] == 5.0
assert params["health_check_interval"] == 30
# Existing params still present
assert params["db"] == 0
assert params["encoding"] == "utf-8"
def test_redis_fallback_with_args():
@redis_fallback(default_return=0)
def test_func(x, y):
raise RedisError("Redis error")
class TestRedisFallback:
def test_redis_fallback_success(self):
@redis_fallback(default_return=None)
def test_func():
return "success"
assert test_func(1, 2) == 0
assert test_func() == "success"
def test_redis_fallback_error(self):
@redis_fallback(default_return="fallback")
def test_func():
raise RedisError("Redis error")
def test_redis_fallback_with_kwargs():
@redis_fallback(default_return={})
def test_func(x=None, y=None):
raise RedisError("Redis error")
assert test_func() == "fallback"
assert test_func(x=1, y=2) == {}
def test_redis_fallback_none_default(self):
@redis_fallback()
def test_func():
raise RedisError("Redis error")
assert test_func() is None
def test_redis_fallback_preserves_function_metadata():
@redis_fallback(default_return=None)
def test_func():
"""Test function docstring"""
pass
def test_redis_fallback_with_args(self):
@redis_fallback(default_return=0)
def test_func(x, y):
raise RedisError("Redis error")
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"
assert test_func(1, 2) == 0
def test_redis_fallback_with_kwargs(self):
@redis_fallback(default_return={})
def test_func(x=None, y=None):
raise RedisError("Redis error")
assert test_func(x=1, y=2) == {}
def test_redis_fallback_preserves_function_metadata(self):
@redis_fallback(default_return=None)
def test_func():
"""Test function docstring"""
pass
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"

8
api/uv.lock generated
View File

@@ -1507,7 +1507,7 @@ requires-dist = [
{ name = "json-repair", specifier = ">=0.55.1" },
{ name = "langfuse", specifier = ">=3.0.0,<5.0.0" },
{ name = "langsmith", specifier = "~=0.7.16" },
{ name = "litellm", specifier = "==1.82.6" },
{ name = "litellm", specifier = "==1.83.0" },
{ name = "markdown", specifier = "~=3.10.2" },
{ name = "mlflow-skinny", specifier = ">=3.0.0" },
{ name = "numpy", specifier = "~=1.26.4" },
@@ -3121,7 +3121,7 @@ wheels = [
[[package]]
name = "litellm"
version = "1.82.6"
version = "1.83.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
@@ -3137,9 +3137,9 @@ dependencies = [
{ name = "tiktoken" },
{ name = "tokenizers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" }
sdist = { url = "https://files.pythonhosted.org/packages/22/92/6ce9737554994ca8e536e5f4f6a87cc7c4774b656c9eb9add071caf7d54b/litellm-1.83.0.tar.gz", hash = "sha256:860bebc76c4bb27b4cf90b4a77acd66dba25aced37e3db98750de8a1766bfb7a", size = 17333062, upload-time = "2026-03-31T05:08:25.331Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" },
{ url = "https://files.pythonhosted.org/packages/19/2c/a670cc050fcd6f45c6199eb99e259c73aea92edba8d5c2fc1b3686d36217/litellm-1.83.0-py3-none-any.whl", hash = "sha256:88c536d339248f3987571493015784671ba3f193a328e1ea6780dbebaa2094a8", size = 15610306, upload-time = "2026-03-31T05:08:21.987Z" },
]
[[package]]

View File

@@ -373,6 +373,20 @@ REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# Redis connection and retry configuration
# max redis retry
REDIS_RETRY_RETRIES=3
# Base delay (in seconds) for exponential backoff on retries
REDIS_RETRY_BACKOFF_BASE=1.0
# Cap (in seconds) for exponential backoff on retries
REDIS_RETRY_BACKOFF_CAP=10.0
# Timeout (in seconds) for Redis socket operations
REDIS_SOCKET_TIMEOUT=5.0
# Timeout (in seconds) for establishing a Redis connection
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
# Interval (in seconds) for Redis health checks
REDIS_HEALTH_CHECK_INTERVAL=30
# ------------------------------
# Celery Configuration
# ------------------------------

View File

@@ -100,6 +100,12 @@ x-shared-env: &shared-api-worker-env
REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false}
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
REDIS_RETRY_RETRIES: ${REDIS_RETRY_RETRIES:-3}
REDIS_RETRY_BACKOFF_BASE: ${REDIS_RETRY_BACKOFF_BASE:-1.0}
REDIS_RETRY_BACKOFF_CAP: ${REDIS_RETRY_BACKOFF_CAP:-10.0}
REDIS_SOCKET_TIMEOUT: ${REDIS_SOCKET_TIMEOUT:-5.0}
REDIS_SOCKET_CONNECT_TIMEOUT: ${REDIS_SOCKET_CONNECT_TIMEOUT:-5.0}
REDIS_HEALTH_CHECK_INTERVAL: ${REDIS_HEALTH_CHECK_INTERVAL:-30}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
CELERY_BACKEND: ${CELERY_BACKEND:-redis}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}

View File

@@ -0,0 +1,11 @@
@apps @authenticated
Feature: Create Chatbot app
Scenario: Create a new Chatbot app and redirect to the configuration page
Given I am signed in as the default E2E admin
When I open the apps console
And I start creating a blank app
And I expand the beginner app types
And I select the "Chatbot" app type
And I enter a unique E2E app name
And I confirm app creation
Then I should land on the app configuration page

View File

@@ -0,0 +1,10 @@
@apps @authenticated
Feature: Create Workflow app
Scenario: Create a new Workflow app and redirect to the workflow editor
Given I am signed in as the default E2E admin
When I open the apps console
And I start creating a blank app
And I select the "Workflow" app type
And I enter a unique E2E app name
And I confirm app creation
Then I should land on the workflow editor

View File

@@ -0,0 +1,8 @@
@auth @authenticated
Feature: Sign out
Scenario: Sign out from the apps console
Given I am signed in as the default E2E admin
When I open the apps console
And I open the account menu
And I sign out
Then I should be on the sign-in page

View File

@@ -24,6 +24,30 @@ When('I confirm app creation', async function (this: DifyWorld) {
await createButton.click()
})
When('I select the {string} app type', async function (this: DifyWorld, appType: string) {
const dialog = this.getPage().getByRole('dialog')
const appTypeTitle = dialog.getByText(appType, { exact: true })
await expect(appTypeTitle).toBeVisible()
await appTypeTitle.click()
})
When('I expand the beginner app types', async function (this: DifyWorld) {
const page = this.getPage()
const toggle = page.getByRole('button', { name: 'More basic app types' })
await expect(toggle).toBeVisible()
await toggle.click()
})
Then('I should land on the app editor', async function (this: DifyWorld) {
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/(workflow|configuration)(?:\?.*)?$/)
})
Then('I should land on the workflow editor', async function (this: DifyWorld) {
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/workflow(?:\?.*)?$/)
})
Then('I should land on the app configuration page', async function (this: DifyWorld) {
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/configuration(?:\?.*)?$/)
})

View File

@@ -0,0 +1,25 @@
import { Then, When } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../../support/world'
When('I open the account menu', async function (this: DifyWorld) {
const page = this.getPage()
const trigger = page.getByRole('button', { name: 'Account' })
await expect(trigger).toBeVisible()
await trigger.click()
})
When('I sign out', async function (this: DifyWorld) {
const page = this.getPage()
await expect(page.getByText('Log out')).toBeVisible()
await page.getByText('Log out').click()
})
Then('I should be on the sign-in page', async function (this: DifyWorld) {
await expect(this.getPage()).toHaveURL(/\/signin/)
await expect(this.getPage().getByRole('button', { name: /^Sign in$/i })).toBeVisible({
timeout: 30_000,
})
})

42
pnpm-lock.yaml generated
View File

@@ -249,6 +249,9 @@ catalogs:
class-variance-authority:
specifier: 0.7.1
version: 0.7.1
client-only:
specifier: 0.0.1
version: 0.0.1
clsx:
specifier: 2.1.1
version: 2.1.1
@@ -324,9 +327,6 @@ catalogs:
fast-deep-equal:
specifier: 3.1.3
version: 3.1.3
foxact:
specifier: 0.3.0
version: 0.3.0
happy-dom:
specifier: 20.8.9
version: 20.8.9
@@ -736,6 +736,9 @@ importers:
class-variance-authority:
specifier: 'catalog:'
version: 0.7.1
client-only:
specifier: 'catalog:'
version: 0.0.1
clsx:
specifier: 'catalog:'
version: 2.1.1
@@ -781,9 +784,6 @@ importers:
fast-deep-equal:
specifier: 'catalog:'
version: 3.1.3
foxact:
specifier: 'catalog:'
version: 0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
hast-util-to-jsx-runtime:
specifier: 'catalog:'
version: 2.3.6
@@ -5871,9 +5871,6 @@ packages:
resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==}
engines: {node: '>=0.10.0'}
event-target-bus@1.0.0:
resolution: {integrity: sha512-uPcWKbj/BJU3Tbw9XqhHqET4/LBOhvv3/SJWr7NksxA6TC5YqBpaZgawE9R+WpYFCBFSAE4Vun+xQS6w4ABdlA==}
events@3.3.0:
resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==}
engines: {node: '>=0.8.x'}
@@ -5986,17 +5983,6 @@ packages:
engines: {node: '>=18.3.0'}
hasBin: true
foxact@0.3.0:
resolution: {integrity: sha512-CSlMlC0KlKQQEO83iLeQCLuT1V0OqnMWj7mjLstIDV8baMe1w4F7z3cz3/T+6Z8W12jqkQj07rwlw4Gi39knGg==}
peerDependencies:
react: '*'
react-dom: '*'
peerDependenciesMeta:
react:
optional: true
react-dom:
optional: true
fs-constants@1.0.0:
resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==}
@@ -7710,9 +7696,6 @@ packages:
resolution: {integrity: sha512-OwrZRZAfhHww0WEnKHDY8OM0U/Qs8OTfIDWhUD4BLpNJUfXK4cGmjiagGze086m+mhI+V2nD0gfbHEnJjb9STA==}
engines: {node: '>=10'}
server-only@0.0.1:
resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==}
sharp@0.34.5:
resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==}
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
@@ -13552,8 +13535,6 @@ snapshots:
esutils@2.0.3: {}
event-target-bus@1.0.0: {}
events@3.3.0: {}
expand-template@2.0.3:
@@ -13661,15 +13642,6 @@ snapshots:
dependencies:
fd-package-json: 2.0.0
foxact@0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4):
dependencies:
client-only: 0.0.1
event-target-bus: 1.0.0
server-only: 0.0.1
optionalDependencies:
react: 19.2.4
react-dom: 19.2.4(react@19.2.4)
fs-constants@1.0.0:
optional: true
@@ -15905,8 +15877,6 @@ snapshots:
seroval@1.5.1: {}
server-only@0.0.1: {}
sharp@0.34.5:
dependencies:
'@img/colour': 1.1.0

View File

@@ -129,6 +129,7 @@ catalog:
ahooks: 3.9.7
autoprefixer: 10.4.27
class-variance-authority: 0.7.1
client-only: 0.0.1
clsx: 2.1.1
cmdk: 1.1.1
code-inspector-plugin: 1.5.1
@@ -154,7 +155,6 @@ catalog:
eslint-plugin-sonarjs: 4.0.2
eslint-plugin-storybook: 10.3.5
fast-deep-equal: 3.1.3
foxact: 0.3.0
happy-dom: 20.8.9
hast-util-to-jsx-runtime: 2.3.6
hono: 4.12.12

View File

@@ -5,7 +5,7 @@ const mockCopy = vi.fn()
const mockReset = vi.fn()
let mockCopied = false
vi.mock('foxact/use-clipboard', () => ({
vi.mock('@/hooks/use-clipboard', () => ({
useClipboard: () => ({
copy: mockCopy,
reset: mockReset,

View File

@@ -3,11 +3,11 @@ import {
RiClipboardFill,
RiClipboardLine,
} from '@remixicon/react'
import { useClipboard } from 'foxact/use-clipboard'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import Tooltip from '@/app/components/base/tooltip'
import { useClipboard } from '@/hooks/use-clipboard'
import copyStyle from './style.module.css'
type Props = {

View File

@@ -5,7 +5,7 @@ const copy = vi.fn()
const reset = vi.fn()
let copied = false
vi.mock('foxact/use-clipboard', () => ({
vi.mock('@/hooks/use-clipboard', () => ({
useClipboard: () => ({
copy,
reset,

View File

@@ -1,7 +1,7 @@
'use client'
import { useClipboard } from 'foxact/use-clipboard'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { useClipboard } from '@/hooks/use-clipboard'
import Tooltip from '../tooltip'
type Props = {

View File

@@ -6,7 +6,7 @@ const mockCopy = vi.fn()
let mockCopied = false
const mockReset = vi.fn()
vi.mock('foxact/use-clipboard', () => ({
vi.mock('@/hooks/use-clipboard', () => ({
useClipboard: () => ({
copy: mockCopy,
copied: mockCopied,

View File

@@ -1,8 +1,8 @@
'use client'
import type { InputProps } from '../input'
import { useClipboard } from 'foxact/use-clipboard'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { useClipboard } from '@/hooks/use-clipboard'
import { cn } from '@/utils/classnames'
import ActionButton from '../action-button'
import Tooltip from '../tooltip'

7
web/hooks/noop.ts Normal file
View File

@@ -0,0 +1,7 @@
type Noop = {
// eslint-disable-next-line ts/no-explicit-any
(...args: any[]): any
}
/** @see https://foxact.skk.moe/noop */
export const noop: Noop = () => { /* noop */ }

View File

@@ -0,0 +1,72 @@
import { useRef, useState } from 'react'
import { writeTextToClipboard } from '@/utils/clipboard'
import { noop } from './noop'
import { useStableHandler } from './use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired'
import { useCallback } from './use-typescript-happy-callback'
import 'client-only'
type UseClipboardOption = {
timeout?: number
usePromptAsFallback?: boolean
promptFallbackText?: string
onCopyError?: (error: Error) => void
}
/** @see https://foxact.skk.moe/use-clipboard */
export function useClipboard({
timeout = 1000,
usePromptAsFallback = false,
promptFallbackText = 'Failed to copy to clipboard automatically, please manually copy the text below.',
onCopyError,
}: UseClipboardOption = {}) {
const [error, setError] = useState<Error | null>(null)
const [copied, setCopied] = useState(false)
const copyTimeoutRef = useRef<number | null>(null)
const stablizedOnCopyError = useStableHandler<[e: Error], void>(onCopyError || noop)
const handleCopyResult = useCallback((isCopied: boolean) => {
if (copyTimeoutRef.current) {
clearTimeout(copyTimeoutRef.current)
}
if (isCopied) {
copyTimeoutRef.current = window.setTimeout(() => setCopied(false), timeout)
}
setCopied(isCopied)
}, [timeout])
const handleCopyError = useCallback((e: Error) => {
setError(e)
stablizedOnCopyError(e)
}, [stablizedOnCopyError])
const copy = useCallback(async (valueToCopy: string) => {
try {
await writeTextToClipboard(valueToCopy)
}
catch (e) {
if (usePromptAsFallback) {
try {
// eslint-disable-next-line no-alert -- prompt as fallback in case of copy error
window.prompt(promptFallbackText, valueToCopy)
}
catch (e2) {
handleCopyError(e2 as Error)
}
}
else {
handleCopyError(e as Error)
}
}
}, [handleCopyResult, promptFallbackText, handleCopyError, usePromptAsFallback])
const reset = useCallback(() => {
setCopied(false)
setError(null)
if (copyTimeoutRef.current) {
clearTimeout(copyTimeoutRef.current)
}
}, [])
return { copy, reset, error, copied }
}

View File

@@ -0,0 +1,44 @@
import * as reactExports from 'react'
import { useCallback, useEffect, useLayoutEffect, useRef } from 'react'
// useIsomorphicInsertionEffect
const useInsertionEffect
= typeof window === 'undefined'
// useInsertionEffect is only available in React 18+
? useEffect
: reactExports.useInsertionEffect || useLayoutEffect
/**
* @see https://foxact.skk.moe/use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired
* Similar to useCallback, with a few subtle differences:
* - The returned function is a stable reference, and will always be the same between renders
* - No dependency lists required
* - Properties or state accessed within the callback will always be "current"
*/
// eslint-disable-next-line ts/no-explicit-any
export function useStableHandler<Args extends any[], Result>(
callback: (...args: Args) => Result,
): typeof callback {
// Keep track of the latest callback:
// eslint-disable-next-line ts/no-explicit-any
const latestRef = useRef<typeof callback>(shouldNotBeInvokedBeforeMount as any)
useInsertionEffect(() => {
latestRef.current = callback
}, [callback])
return useCallback<typeof callback>((...args) => {
const fn = latestRef.current
return fn(...args)
}, [])
}
/**
* Render methods should be pure, especially when concurrency is used,
* so we will throw this error if the callback is called while rendering.
*/
function shouldNotBeInvokedBeforeMount() {
throw new Error(
'foxact: the stablized handler cannot be invoked before the component has mounted.',
)
}

View File

@@ -0,0 +1,10 @@
import { useCallback as useCallbackFromReact } from 'react'
/** @see https://foxact.skk.moe/use-typescript-happy-callback */
const useTypeScriptHappyCallback: <Args extends unknown[], R>(
fn: (...args: Args) => R,
deps: React.DependencyList,
) => (...args: Args) => R = useCallbackFromReact
/** @see https://foxact.skk.moe/use-typescript-happy-callback */
export const useCallback = useTypeScriptHappyCallback

View File

@@ -85,6 +85,7 @@
"abcjs": "catalog:",
"ahooks": "catalog:",
"class-variance-authority": "catalog:",
"client-only": "catalog:",
"clsx": "catalog:",
"cmdk": "catalog:",
"copy-to-clipboard": "catalog:",
@@ -100,7 +101,6 @@
"emoji-mart": "catalog:",
"es-toolkit": "catalog:",
"fast-deep-equal": "catalog:",
"foxact": "catalog:",
"hast-util-to-jsx-runtime": "catalog:",
"html-entities": "catalog:",
"html-to-image": "catalog:",

View File

@@ -83,11 +83,12 @@ afterEach(async () => {
})
})
// mock foxact/use-clipboard - not available in test environment
vi.mock('foxact/use-clipboard', () => ({
// mock custom clipboard hook - wraps writeTextToClipboard with fallback
vi.mock('@/hooks/use-clipboard', () => ({
useClipboard: () => ({
copy: vi.fn(),
copied: false,
reset: vi.fn(),
}),
}))