mirror of
https://github.com/langgenius/dify.git
synced 2026-04-13 12:00:19 -04:00
Merge branch 'main' into jzh
This commit is contained in:
@@ -71,6 +71,13 @@ REDIS_USE_CLUSTERS=false
|
||||
REDIS_CLUSTERS=
|
||||
REDIS_CLUSTERS_PASSWORD=
|
||||
|
||||
REDIS_RETRY_RETRIES=3
|
||||
REDIS_RETRY_BACKOFF_BASE=1.0
|
||||
REDIS_RETRY_BACKOFF_CAP=10.0
|
||||
REDIS_SOCKET_TIMEOUT=5.0
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
|
||||
REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
|
||||
31
api/configs/middleware/cache/redis_config.py
vendored
31
api/configs/middleware/cache/redis_config.py
vendored
@@ -117,6 +117,37 @@ class RedisConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
|
||||
description="Maximum number of retries per Redis command on "
|
||||
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
|
||||
default=3,
|
||||
)
|
||||
|
||||
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
|
||||
description="Base delay in seconds for exponential backoff between retries",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
|
||||
description="Maximum backoff delay in seconds between retries",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Socket timeout in seconds for Redis read/write operations",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Socket timeout in seconds for Redis connection establishment",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
|
||||
description="Interval in seconds between Redis connection health checks (0 to disable)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
|
||||
@@ -384,24 +384,27 @@ class VariableApi(Resource):
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
match variable.value_type:
|
||||
case SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
|
||||
@@ -223,24 +223,27 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
match variable.value_type:
|
||||
case SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
|
||||
@@ -168,12 +168,13 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
case AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
case _:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
|
||||
@@ -72,12 +72,13 @@ class WorkflowEventsApi(WebApiResource):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
match app_mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
case AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
case _:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
|
||||
@@ -57,36 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
match system_configuration.current_quota_type:
|
||||
case ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
session.execute(stmt)
|
||||
case ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
case ProviderQuotaType.FREE:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
|
||||
@@ -40,41 +40,44 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
if message_file.url.startswith(("http://", "https://")):
|
||||
match message_file.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0]
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
case FileTransferMethod.TOOL_FILE if message_file.url:
|
||||
if message_file.url.startswith(("http://", "https://")):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0]
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
|
||||
extension = ".bin"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE:
|
||||
pass
|
||||
|
||||
transfer_method_value = message_file.transfer_method.value
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
|
||||
@@ -187,15 +187,16 @@ def build_parameter_schema(
|
||||
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
match app.mode:
|
||||
case AppMode.WORKFLOW:
|
||||
return {"inputs": arguments}
|
||||
case AppMode.COMPLETION:
|
||||
return {"query": "", "inputs": arguments}
|
||||
case _:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
|
||||
|
||||
def extract_answer_from_response(app: App, response: Any) -> str:
|
||||
@@ -229,17 +230,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||
|
||||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT,
|
||||
AppMode.COMPLETION,
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
|
||||
return response.get("answer", "")
|
||||
case AppMode.WORKFLOW:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
case _:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
||||
|
||||
def convert_input_form_to_parameters(
|
||||
|
||||
@@ -72,17 +72,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
|
||||
raise ValueError("unexpected app type")
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
case AppMode.WORKFLOW:
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
case AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
case _:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@classmethod
|
||||
def invoke_chat_app(
|
||||
@@ -98,60 +99,61 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke chat app
|
||||
"""
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=stream,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
case AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
case AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
case _:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=stream,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@classmethod
|
||||
def invoke_workflow_app(
|
||||
cls,
|
||||
|
||||
@@ -961,36 +961,37 @@ class ProviderManager:
|
||||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
match provider_quota.quota_type:
|
||||
case ProviderQuotaType.TRIAL if trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
case ProviderQuotaType.PAID if paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
else:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
case _:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configurations.append(quota_configuration)
|
||||
|
||||
|
||||
@@ -7,14 +7,13 @@ from sqlalchemy import select
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
@@ -17,18 +16,6 @@ from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
reranking_mode: NotRequired[str]
|
||||
weights: NotRequired[WeightsDict | None]
|
||||
score_threshold: NotRequired[float]
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
|
||||
@@ -155,24 +155,25 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
outputs[param_name] = raw_data
|
||||
continue
|
||||
|
||||
if param_type == SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
file = files.get(param_name)
|
||||
if file and isinstance(file, dict):
|
||||
file_var = self.generate_file_var(param_name, file)
|
||||
if file_var:
|
||||
outputs[param_name] = file_var
|
||||
match param_type:
|
||||
case SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
file = files.get(param_name)
|
||||
if file and isinstance(file, dict):
|
||||
file_var = self.generate_file_var(param_name, file)
|
||||
if file_var:
|
||||
outputs[param_name] = file_var
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
# Get regular body parameter
|
||||
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
|
||||
case _:
|
||||
# Get regular body parameter
|
||||
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
|
||||
|
||||
# Include raw webhook data for debugging/advanced use
|
||||
outputs["_webhook_raw"] = webhook_data
|
||||
|
||||
@@ -68,46 +68,49 @@ class EnterpriseMetricHandler:
|
||||
|
||||
# Route to appropriate handler based on case
|
||||
case = envelope.case
|
||||
if case == TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
elif case == TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
elif case == TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
elif case == TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
elif case == TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
elif case == TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
elif case == TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
elif case == TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
elif case == TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
elif case == TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
elif case == TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
match case:
|
||||
case TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
case TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
case TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
case TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
case TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
case TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
case TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
case TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
case TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
case TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
case TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
case TelemetryCase.WORKFLOW_RUN | TelemetryCase.NODE_EXECUTION | TelemetryCase.DRAFT_NODE_EXECUTION:
|
||||
pass
|
||||
case _:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
|
||||
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
|
||||
"""Check if this event has already been processed.
|
||||
|
||||
@@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
from redis.backoff import ExponentialWithJitterBackoff # type: ignore
|
||||
from redis.cache import CacheConfig
|
||||
from redis.client import PubSub
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.retry import Retry
|
||||
from redis.sentinel import Sentinel
|
||||
|
||||
from configs import dify_config
|
||||
@@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None:
|
||||
return CacheConfig()
|
||||
|
||||
|
||||
def _get_retry_policy() -> Retry:
|
||||
"""Build the shared retry policy for Redis connections."""
|
||||
return Retry(
|
||||
backoff=ExponentialWithJitterBackoff(
|
||||
base=dify_config.REDIS_RETRY_BACKOFF_BASE,
|
||||
cap=dify_config.REDIS_RETRY_BACKOFF_CAP,
|
||||
),
|
||||
retries=dify_config.REDIS_RETRY_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
def _get_connection_health_params() -> dict[str, Any]:
|
||||
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
|
||||
return {
|
||||
"retry": _get_retry_policy(),
|
||||
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
|
||||
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
|
||||
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
|
||||
}
|
||||
|
||||
|
||||
def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
"""Get retry and timeout parameters for Redis Cluster clients.
|
||||
|
||||
RedisCluster does not support ``health_check_interval`` as a constructor
|
||||
keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded
|
||||
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
|
||||
are passed through.
|
||||
"""
|
||||
params = _get_connection_health_params()
|
||||
return {k: v for k, v in params.items() if k != "health_check_interval"}
|
||||
|
||||
|
||||
def _get_base_redis_params() -> dict[str, Any]:
|
||||
"""Get base Redis connection parameters."""
|
||||
"""Get base Redis connection parameters including retry and health policy."""
|
||||
return {
|
||||
"username": dify_config.REDIS_USERNAME,
|
||||
"password": dify_config.REDIS_PASSWORD or None,
|
||||
@@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]:
|
||||
"decode_responses": False,
|
||||
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
"cache_config": _get_cache_configuration(),
|
||||
**_get_connection_health_params(),
|
||||
}
|
||||
|
||||
|
||||
@@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
|
||||
"password": dify_config.REDIS_CLUSTERS_PASSWORD,
|
||||
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
"cache_config": _get_cache_configuration(),
|
||||
**_get_cluster_connection_health_params(),
|
||||
}
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
@@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
||||
"""Create standalone Redis client."""
|
||||
connection_class, ssl_kwargs = _get_ssl_configuration()
|
||||
|
||||
redis_params.update(
|
||||
params = {**redis_params}
|
||||
params.update(
|
||||
{
|
||||
"host": dify_config.REDIS_HOST,
|
||||
"port": dify_config.REDIS_PORT,
|
||||
@@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
||||
)
|
||||
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
|
||||
if ssl_kwargs:
|
||||
redis_params.update(ssl_kwargs)
|
||||
params.update(ssl_kwargs)
|
||||
|
||||
pool = redis.ConnectionPool(**redis_params)
|
||||
pool = redis.ConnectionPool(**params)
|
||||
client: redis.Redis = redis.Redis(connection_pool=pool)
|
||||
return client
|
||||
|
||||
|
||||
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
|
||||
max_conns = dify_config.REDIS_MAX_CONNECTIONS
|
||||
if use_clusters:
|
||||
if max_conns:
|
||||
return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
|
||||
else:
|
||||
return RedisCluster.from_url(pubsub_url)
|
||||
|
||||
if use_clusters:
|
||||
health_params = _get_cluster_connection_health_params()
|
||||
kwargs: dict[str, Any] = {**health_params}
|
||||
if max_conns:
|
||||
kwargs["max_connections"] = max_conns
|
||||
return RedisCluster.from_url(pubsub_url, **kwargs)
|
||||
|
||||
health_params = _get_connection_health_params()
|
||||
kwargs = {**health_params}
|
||||
if max_conns:
|
||||
return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
|
||||
else:
|
||||
return redis.Redis.from_url(pubsub_url)
|
||||
kwargs["max_connections"] = max_conns
|
||||
return redis.Redis.from_url(pubsub_url, **kwargs)
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
|
||||
@@ -1,53 +1,125 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from redis import RedisError
|
||||
from redis.retry import Retry
|
||||
|
||||
from extensions.ext_redis import redis_fallback
|
||||
from extensions.ext_redis import (
|
||||
_get_base_redis_params,
|
||||
_get_cluster_connection_health_params,
|
||||
_get_connection_health_params,
|
||||
redis_fallback,
|
||||
)
|
||||
|
||||
|
||||
def test_redis_fallback_success():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
class TestGetConnectionHealthParams:
|
||||
@patch("extensions.ext_redis.dify_config")
|
||||
def test_includes_all_health_params(self, mock_config):
|
||||
mock_config.REDIS_RETRY_RETRIES = 3
|
||||
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
|
||||
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
|
||||
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
|
||||
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
|
||||
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
|
||||
|
||||
assert test_func() == "success"
|
||||
params = _get_connection_health_params()
|
||||
|
||||
assert "retry" in params
|
||||
assert "socket_timeout" in params
|
||||
assert "socket_connect_timeout" in params
|
||||
assert "health_check_interval" in params
|
||||
assert isinstance(params["retry"], Retry)
|
||||
assert params["retry"]._retries == 3
|
||||
assert params["socket_timeout"] == 5.0
|
||||
assert params["socket_connect_timeout"] == 5.0
|
||||
assert params["health_check_interval"] == 30
|
||||
|
||||
|
||||
def test_redis_fallback_error():
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
class TestGetClusterConnectionHealthParams:
|
||||
@patch("extensions.ext_redis.dify_config")
|
||||
def test_excludes_health_check_interval(self, mock_config):
|
||||
mock_config.REDIS_RETRY_RETRIES = 3
|
||||
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
|
||||
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
|
||||
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
|
||||
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
|
||||
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
|
||||
|
||||
assert test_func() == "fallback"
|
||||
params = _get_cluster_connection_health_params()
|
||||
|
||||
assert "retry" in params
|
||||
assert "socket_timeout" in params
|
||||
assert "socket_connect_timeout" in params
|
||||
assert "health_check_interval" not in params
|
||||
|
||||
|
||||
def test_redis_fallback_none_default():
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
class TestGetBaseRedisParams:
|
||||
@patch("extensions.ext_redis.dify_config")
|
||||
def test_includes_retry_and_health_params(self, mock_config):
|
||||
mock_config.REDIS_USERNAME = None
|
||||
mock_config.REDIS_PASSWORD = None
|
||||
mock_config.REDIS_DB = 0
|
||||
mock_config.REDIS_SERIALIZATION_PROTOCOL = 3
|
||||
mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False
|
||||
mock_config.REDIS_RETRY_RETRIES = 3
|
||||
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
|
||||
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
|
||||
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
|
||||
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
|
||||
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
|
||||
|
||||
assert test_func() is None
|
||||
params = _get_base_redis_params()
|
||||
|
||||
assert "retry" in params
|
||||
assert isinstance(params["retry"], Retry)
|
||||
assert params["socket_timeout"] == 5.0
|
||||
assert params["socket_connect_timeout"] == 5.0
|
||||
assert params["health_check_interval"] == 30
|
||||
# Existing params still present
|
||||
assert params["db"] == 0
|
||||
assert params["encoding"] == "utf-8"
|
||||
|
||||
|
||||
def test_redis_fallback_with_args():
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
class TestRedisFallback:
|
||||
def test_redis_fallback_success(self):
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
assert test_func(1, 2) == 0
|
||||
assert test_func() == "success"
|
||||
|
||||
def test_redis_fallback_error(self):
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
def test_redis_fallback_with_kwargs():
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
assert test_func() == "fallback"
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
def test_redis_fallback_none_default(self):
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() is None
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
def test_redis_fallback_with_args(self):
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
assert test_func(1, 2) == 0
|
||||
|
||||
def test_redis_fallback_with_kwargs(self):
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata(self):
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
|
||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@@ -1507,7 +1507,7 @@ requires-dist = [
|
||||
{ name = "json-repair", specifier = ">=0.55.1" },
|
||||
{ name = "langfuse", specifier = ">=3.0.0,<5.0.0" },
|
||||
{ name = "langsmith", specifier = "~=0.7.16" },
|
||||
{ name = "litellm", specifier = "==1.82.6" },
|
||||
{ name = "litellm", specifier = "==1.83.0" },
|
||||
{ name = "markdown", specifier = "~=3.10.2" },
|
||||
{ name = "mlflow-skinny", specifier = ">=3.0.0" },
|
||||
{ name = "numpy", specifier = "~=1.26.4" },
|
||||
@@ -3121,7 +3121,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.82.6"
|
||||
version = "1.83.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
@@ -3137,9 +3137,9 @@ dependencies = [
|
||||
{ name = "tiktoken" },
|
||||
{ name = "tokenizers" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/92/6ce9737554994ca8e536e5f4f6a87cc7c4774b656c9eb9add071caf7d54b/litellm-1.83.0.tar.gz", hash = "sha256:860bebc76c4bb27b4cf90b4a77acd66dba25aced37e3db98750de8a1766bfb7a", size = 17333062, upload-time = "2026-03-31T05:08:25.331Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/2c/a670cc050fcd6f45c6199eb99e259c73aea92edba8d5c2fc1b3686d36217/litellm-1.83.0-py3-none-any.whl", hash = "sha256:88c536d339248f3987571493015784671ba3f193a328e1ea6780dbebaa2094a8", size = 15610306, upload-time = "2026-03-31T05:08:21.987Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -373,6 +373,20 @@ REDIS_USE_CLUSTERS=false
|
||||
REDIS_CLUSTERS=
|
||||
REDIS_CLUSTERS_PASSWORD=
|
||||
|
||||
# Redis connection and retry configuration
|
||||
# max redis retry
|
||||
REDIS_RETRY_RETRIES=3
|
||||
# Base delay (in seconds) for exponential backoff on retries
|
||||
REDIS_RETRY_BACKOFF_BASE=1.0
|
||||
# Cap (in seconds) for exponential backoff on retries
|
||||
REDIS_RETRY_BACKOFF_CAP=10.0
|
||||
# Timeout (in seconds) for Redis socket operations
|
||||
REDIS_SOCKET_TIMEOUT=5.0
|
||||
# Timeout (in seconds) for establishing a Redis connection
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
|
||||
# Interval (in seconds) for Redis health checks
|
||||
REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
|
||||
# ------------------------------
|
||||
# Celery Configuration
|
||||
# ------------------------------
|
||||
|
||||
@@ -100,6 +100,12 @@ x-shared-env: &shared-api-worker-env
|
||||
REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false}
|
||||
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
|
||||
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
|
||||
REDIS_RETRY_RETRIES: ${REDIS_RETRY_RETRIES:-3}
|
||||
REDIS_RETRY_BACKOFF_BASE: ${REDIS_RETRY_BACKOFF_BASE:-1.0}
|
||||
REDIS_RETRY_BACKOFF_CAP: ${REDIS_RETRY_BACKOFF_CAP:-10.0}
|
||||
REDIS_SOCKET_TIMEOUT: ${REDIS_SOCKET_TIMEOUT:-5.0}
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT: ${REDIS_SOCKET_CONNECT_TIMEOUT:-5.0}
|
||||
REDIS_HEALTH_CHECK_INTERVAL: ${REDIS_HEALTH_CHECK_INTERVAL:-30}
|
||||
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
|
||||
CELERY_BACKEND: ${CELERY_BACKEND:-redis}
|
||||
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
|
||||
|
||||
11
e2e/features/apps/create-chatbot-app.feature
Normal file
11
e2e/features/apps/create-chatbot-app.feature
Normal file
@@ -0,0 +1,11 @@
|
||||
@apps @authenticated
|
||||
Feature: Create Chatbot app
|
||||
Scenario: Create a new Chatbot app and redirect to the configuration page
|
||||
Given I am signed in as the default E2E admin
|
||||
When I open the apps console
|
||||
And I start creating a blank app
|
||||
And I expand the beginner app types
|
||||
And I select the "Chatbot" app type
|
||||
And I enter a unique E2E app name
|
||||
And I confirm app creation
|
||||
Then I should land on the app configuration page
|
||||
10
e2e/features/apps/create-workflow-app.feature
Normal file
10
e2e/features/apps/create-workflow-app.feature
Normal file
@@ -0,0 +1,10 @@
|
||||
@apps @authenticated
|
||||
Feature: Create Workflow app
|
||||
Scenario: Create a new Workflow app and redirect to the workflow editor
|
||||
Given I am signed in as the default E2E admin
|
||||
When I open the apps console
|
||||
And I start creating a blank app
|
||||
And I select the "Workflow" app type
|
||||
And I enter a unique E2E app name
|
||||
And I confirm app creation
|
||||
Then I should land on the workflow editor
|
||||
8
e2e/features/auth/sign-out.feature
Normal file
8
e2e/features/auth/sign-out.feature
Normal file
@@ -0,0 +1,8 @@
|
||||
@auth @authenticated
|
||||
Feature: Sign out
|
||||
Scenario: Sign out from the apps console
|
||||
Given I am signed in as the default E2E admin
|
||||
When I open the apps console
|
||||
And I open the account menu
|
||||
And I sign out
|
||||
Then I should be on the sign-in page
|
||||
@@ -24,6 +24,30 @@ When('I confirm app creation', async function (this: DifyWorld) {
|
||||
await createButton.click()
|
||||
})
|
||||
|
||||
When('I select the {string} app type', async function (this: DifyWorld, appType: string) {
|
||||
const dialog = this.getPage().getByRole('dialog')
|
||||
const appTypeTitle = dialog.getByText(appType, { exact: true })
|
||||
|
||||
await expect(appTypeTitle).toBeVisible()
|
||||
await appTypeTitle.click()
|
||||
})
|
||||
|
||||
When('I expand the beginner app types', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
const toggle = page.getByRole('button', { name: 'More basic app types' })
|
||||
|
||||
await expect(toggle).toBeVisible()
|
||||
await toggle.click()
|
||||
})
|
||||
|
||||
Then('I should land on the app editor', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/(workflow|configuration)(?:\?.*)?$/)
|
||||
})
|
||||
|
||||
Then('I should land on the workflow editor', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/workflow(?:\?.*)?$/)
|
||||
})
|
||||
|
||||
Then('I should land on the app configuration page', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/configuration(?:\?.*)?$/)
|
||||
})
|
||||
|
||||
25
e2e/features/step-definitions/auth/sign-out.steps.ts
Normal file
25
e2e/features/step-definitions/auth/sign-out.steps.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { Then, When } from '@cucumber/cucumber'
|
||||
import { expect } from '@playwright/test'
|
||||
import type { DifyWorld } from '../../support/world'
|
||||
|
||||
When('I open the account menu', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
const trigger = page.getByRole('button', { name: 'Account' })
|
||||
|
||||
await expect(trigger).toBeVisible()
|
||||
await trigger.click()
|
||||
})
|
||||
|
||||
When('I sign out', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
|
||||
await expect(page.getByText('Log out')).toBeVisible()
|
||||
await page.getByText('Log out').click()
|
||||
})
|
||||
|
||||
Then('I should be on the sign-in page', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/signin/)
|
||||
await expect(this.getPage().getByRole('button', { name: /^Sign in$/i })).toBeVisible({
|
||||
timeout: 30_000,
|
||||
})
|
||||
})
|
||||
42
pnpm-lock.yaml
generated
42
pnpm-lock.yaml
generated
@@ -249,6 +249,9 @@ catalogs:
|
||||
class-variance-authority:
|
||||
specifier: 0.7.1
|
||||
version: 0.7.1
|
||||
client-only:
|
||||
specifier: 0.0.1
|
||||
version: 0.0.1
|
||||
clsx:
|
||||
specifier: 2.1.1
|
||||
version: 2.1.1
|
||||
@@ -324,9 +327,6 @@ catalogs:
|
||||
fast-deep-equal:
|
||||
specifier: 3.1.3
|
||||
version: 3.1.3
|
||||
foxact:
|
||||
specifier: 0.3.0
|
||||
version: 0.3.0
|
||||
happy-dom:
|
||||
specifier: 20.8.9
|
||||
version: 20.8.9
|
||||
@@ -736,6 +736,9 @@ importers:
|
||||
class-variance-authority:
|
||||
specifier: 'catalog:'
|
||||
version: 0.7.1
|
||||
client-only:
|
||||
specifier: 'catalog:'
|
||||
version: 0.0.1
|
||||
clsx:
|
||||
specifier: 'catalog:'
|
||||
version: 2.1.1
|
||||
@@ -781,9 +784,6 @@ importers:
|
||||
fast-deep-equal:
|
||||
specifier: 'catalog:'
|
||||
version: 3.1.3
|
||||
foxact:
|
||||
specifier: 'catalog:'
|
||||
version: 0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
hast-util-to-jsx-runtime:
|
||||
specifier: 'catalog:'
|
||||
version: 2.3.6
|
||||
@@ -5871,9 +5871,6 @@ packages:
|
||||
resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
|
||||
event-target-bus@1.0.0:
|
||||
resolution: {integrity: sha512-uPcWKbj/BJU3Tbw9XqhHqET4/LBOhvv3/SJWr7NksxA6TC5YqBpaZgawE9R+WpYFCBFSAE4Vun+xQS6w4ABdlA==}
|
||||
|
||||
events@3.3.0:
|
||||
resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==}
|
||||
engines: {node: '>=0.8.x'}
|
||||
@@ -5986,17 +5983,6 @@ packages:
|
||||
engines: {node: '>=18.3.0'}
|
||||
hasBin: true
|
||||
|
||||
foxact@0.3.0:
|
||||
resolution: {integrity: sha512-CSlMlC0KlKQQEO83iLeQCLuT1V0OqnMWj7mjLstIDV8baMe1w4F7z3cz3/T+6Z8W12jqkQj07rwlw4Gi39knGg==}
|
||||
peerDependencies:
|
||||
react: '*'
|
||||
react-dom: '*'
|
||||
peerDependenciesMeta:
|
||||
react:
|
||||
optional: true
|
||||
react-dom:
|
||||
optional: true
|
||||
|
||||
fs-constants@1.0.0:
|
||||
resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==}
|
||||
|
||||
@@ -7710,9 +7696,6 @@ packages:
|
||||
resolution: {integrity: sha512-OwrZRZAfhHww0WEnKHDY8OM0U/Qs8OTfIDWhUD4BLpNJUfXK4cGmjiagGze086m+mhI+V2nD0gfbHEnJjb9STA==}
|
||||
engines: {node: '>=10'}
|
||||
|
||||
server-only@0.0.1:
|
||||
resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==}
|
||||
|
||||
sharp@0.34.5:
|
||||
resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
@@ -13552,8 +13535,6 @@ snapshots:
|
||||
|
||||
esutils@2.0.3: {}
|
||||
|
||||
event-target-bus@1.0.0: {}
|
||||
|
||||
events@3.3.0: {}
|
||||
|
||||
expand-template@2.0.3:
|
||||
@@ -13661,15 +13642,6 @@ snapshots:
|
||||
dependencies:
|
||||
fd-package-json: 2.0.0
|
||||
|
||||
foxact@0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4):
|
||||
dependencies:
|
||||
client-only: 0.0.1
|
||||
event-target-bus: 1.0.0
|
||||
server-only: 0.0.1
|
||||
optionalDependencies:
|
||||
react: 19.2.4
|
||||
react-dom: 19.2.4(react@19.2.4)
|
||||
|
||||
fs-constants@1.0.0:
|
||||
optional: true
|
||||
|
||||
@@ -15905,8 +15877,6 @@ snapshots:
|
||||
|
||||
seroval@1.5.1: {}
|
||||
|
||||
server-only@0.0.1: {}
|
||||
|
||||
sharp@0.34.5:
|
||||
dependencies:
|
||||
'@img/colour': 1.1.0
|
||||
|
||||
@@ -129,6 +129,7 @@ catalog:
|
||||
ahooks: 3.9.7
|
||||
autoprefixer: 10.4.27
|
||||
class-variance-authority: 0.7.1
|
||||
client-only: 0.0.1
|
||||
clsx: 2.1.1
|
||||
cmdk: 1.1.1
|
||||
code-inspector-plugin: 1.5.1
|
||||
@@ -154,7 +155,6 @@ catalog:
|
||||
eslint-plugin-sonarjs: 4.0.2
|
||||
eslint-plugin-storybook: 10.3.5
|
||||
fast-deep-equal: 3.1.3
|
||||
foxact: 0.3.0
|
||||
happy-dom: 20.8.9
|
||||
hast-util-to-jsx-runtime: 2.3.6
|
||||
hono: 4.12.12
|
||||
|
||||
@@ -5,7 +5,7 @@ const mockCopy = vi.fn()
|
||||
const mockReset = vi.fn()
|
||||
let mockCopied = false
|
||||
|
||||
vi.mock('foxact/use-clipboard', () => ({
|
||||
vi.mock('@/hooks/use-clipboard', () => ({
|
||||
useClipboard: () => ({
|
||||
copy: mockCopy,
|
||||
reset: mockReset,
|
||||
|
||||
@@ -3,11 +3,11 @@ import {
|
||||
RiClipboardFill,
|
||||
RiClipboardLine,
|
||||
} from '@remixicon/react'
|
||||
import { useClipboard } from 'foxact/use-clipboard'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { useClipboard } from '@/hooks/use-clipboard'
|
||||
import copyStyle from './style.module.css'
|
||||
|
||||
type Props = {
|
||||
|
||||
@@ -5,7 +5,7 @@ const copy = vi.fn()
|
||||
const reset = vi.fn()
|
||||
let copied = false
|
||||
|
||||
vi.mock('foxact/use-clipboard', () => ({
|
||||
vi.mock('@/hooks/use-clipboard', () => ({
|
||||
useClipboard: () => ({
|
||||
copy,
|
||||
reset,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
import { useClipboard } from 'foxact/use-clipboard'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useClipboard } from '@/hooks/use-clipboard'
|
||||
import Tooltip from '../tooltip'
|
||||
|
||||
type Props = {
|
||||
|
||||
@@ -6,7 +6,7 @@ const mockCopy = vi.fn()
|
||||
let mockCopied = false
|
||||
const mockReset = vi.fn()
|
||||
|
||||
vi.mock('foxact/use-clipboard', () => ({
|
||||
vi.mock('@/hooks/use-clipboard', () => ({
|
||||
useClipboard: () => ({
|
||||
copy: mockCopy,
|
||||
copied: mockCopied,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
'use client'
|
||||
import type { InputProps } from '../input'
|
||||
import { useClipboard } from 'foxact/use-clipboard'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useClipboard } from '@/hooks/use-clipboard'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import ActionButton from '../action-button'
|
||||
import Tooltip from '../tooltip'
|
||||
|
||||
7
web/hooks/noop.ts
Normal file
7
web/hooks/noop.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
type Noop = {
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
(...args: any[]): any
|
||||
}
|
||||
|
||||
/** @see https://foxact.skk.moe/noop */
|
||||
export const noop: Noop = () => { /* noop */ }
|
||||
72
web/hooks/use-clipboard.ts
Normal file
72
web/hooks/use-clipboard.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import { useRef, useState } from 'react'
|
||||
import { writeTextToClipboard } from '@/utils/clipboard'
|
||||
import { noop } from './noop'
|
||||
import { useStableHandler } from './use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired'
|
||||
import { useCallback } from './use-typescript-happy-callback'
|
||||
import 'client-only'
|
||||
|
||||
type UseClipboardOption = {
|
||||
timeout?: number
|
||||
usePromptAsFallback?: boolean
|
||||
promptFallbackText?: string
|
||||
onCopyError?: (error: Error) => void
|
||||
}
|
||||
|
||||
/** @see https://foxact.skk.moe/use-clipboard */
|
||||
export function useClipboard({
|
||||
timeout = 1000,
|
||||
usePromptAsFallback = false,
|
||||
promptFallbackText = 'Failed to copy to clipboard automatically, please manually copy the text below.',
|
||||
onCopyError,
|
||||
}: UseClipboardOption = {}) {
|
||||
const [error, setError] = useState<Error | null>(null)
|
||||
const [copied, setCopied] = useState(false)
|
||||
const copyTimeoutRef = useRef<number | null>(null)
|
||||
|
||||
const stablizedOnCopyError = useStableHandler<[e: Error], void>(onCopyError || noop)
|
||||
|
||||
const handleCopyResult = useCallback((isCopied: boolean) => {
|
||||
if (copyTimeoutRef.current) {
|
||||
clearTimeout(copyTimeoutRef.current)
|
||||
}
|
||||
if (isCopied) {
|
||||
copyTimeoutRef.current = window.setTimeout(() => setCopied(false), timeout)
|
||||
}
|
||||
setCopied(isCopied)
|
||||
}, [timeout])
|
||||
|
||||
const handleCopyError = useCallback((e: Error) => {
|
||||
setError(e)
|
||||
stablizedOnCopyError(e)
|
||||
}, [stablizedOnCopyError])
|
||||
|
||||
const copy = useCallback(async (valueToCopy: string) => {
|
||||
try {
|
||||
await writeTextToClipboard(valueToCopy)
|
||||
}
|
||||
catch (e) {
|
||||
if (usePromptAsFallback) {
|
||||
try {
|
||||
// eslint-disable-next-line no-alert -- prompt as fallback in case of copy error
|
||||
window.prompt(promptFallbackText, valueToCopy)
|
||||
}
|
||||
catch (e2) {
|
||||
handleCopyError(e2 as Error)
|
||||
}
|
||||
}
|
||||
else {
|
||||
handleCopyError(e as Error)
|
||||
}
|
||||
}
|
||||
}, [handleCopyResult, promptFallbackText, handleCopyError, usePromptAsFallback])
|
||||
|
||||
const reset = useCallback(() => {
|
||||
setCopied(false)
|
||||
setError(null)
|
||||
if (copyTimeoutRef.current) {
|
||||
clearTimeout(copyTimeoutRef.current)
|
||||
}
|
||||
}, [])
|
||||
|
||||
return { copy, reset, error, copied }
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
import * as reactExports from 'react'
|
||||
import { useCallback, useEffect, useLayoutEffect, useRef } from 'react'
|
||||
|
||||
// useIsomorphicInsertionEffect
|
||||
const useInsertionEffect
|
||||
= typeof window === 'undefined'
|
||||
// useInsertionEffect is only available in React 18+
|
||||
|
||||
? useEffect
|
||||
: reactExports.useInsertionEffect || useLayoutEffect
|
||||
|
||||
/**
|
||||
* @see https://foxact.skk.moe/use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired
|
||||
* Similar to useCallback, with a few subtle differences:
|
||||
* - The returned function is a stable reference, and will always be the same between renders
|
||||
* - No dependency lists required
|
||||
* - Properties or state accessed within the callback will always be "current"
|
||||
*/
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
export function useStableHandler<Args extends any[], Result>(
|
||||
callback: (...args: Args) => Result,
|
||||
): typeof callback {
|
||||
// Keep track of the latest callback:
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
const latestRef = useRef<typeof callback>(shouldNotBeInvokedBeforeMount as any)
|
||||
useInsertionEffect(() => {
|
||||
latestRef.current = callback
|
||||
}, [callback])
|
||||
|
||||
return useCallback<typeof callback>((...args) => {
|
||||
const fn = latestRef.current
|
||||
return fn(...args)
|
||||
}, [])
|
||||
}
|
||||
|
||||
/**
|
||||
* Render methods should be pure, especially when concurrency is used,
|
||||
* so we will throw this error if the callback is called while rendering.
|
||||
*/
|
||||
function shouldNotBeInvokedBeforeMount() {
|
||||
throw new Error(
|
||||
'foxact: the stablized handler cannot be invoked before the component has mounted.',
|
||||
)
|
||||
}
|
||||
10
web/hooks/use-typescript-happy-callback.ts
Normal file
10
web/hooks/use-typescript-happy-callback.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { useCallback as useCallbackFromReact } from 'react'
|
||||
|
||||
/** @see https://foxact.skk.moe/use-typescript-happy-callback */
|
||||
const useTypeScriptHappyCallback: <Args extends unknown[], R>(
|
||||
fn: (...args: Args) => R,
|
||||
deps: React.DependencyList,
|
||||
) => (...args: Args) => R = useCallbackFromReact
|
||||
|
||||
/** @see https://foxact.skk.moe/use-typescript-happy-callback */
|
||||
export const useCallback = useTypeScriptHappyCallback
|
||||
@@ -85,6 +85,7 @@
|
||||
"abcjs": "catalog:",
|
||||
"ahooks": "catalog:",
|
||||
"class-variance-authority": "catalog:",
|
||||
"client-only": "catalog:",
|
||||
"clsx": "catalog:",
|
||||
"cmdk": "catalog:",
|
||||
"copy-to-clipboard": "catalog:",
|
||||
@@ -100,7 +101,6 @@
|
||||
"emoji-mart": "catalog:",
|
||||
"es-toolkit": "catalog:",
|
||||
"fast-deep-equal": "catalog:",
|
||||
"foxact": "catalog:",
|
||||
"hast-util-to-jsx-runtime": "catalog:",
|
||||
"html-entities": "catalog:",
|
||||
"html-to-image": "catalog:",
|
||||
|
||||
@@ -83,11 +83,12 @@ afterEach(async () => {
|
||||
})
|
||||
})
|
||||
|
||||
// mock foxact/use-clipboard - not available in test environment
|
||||
vi.mock('foxact/use-clipboard', () => ({
|
||||
// mock custom clipboard hook - wraps writeTextToClipboard with fallback
|
||||
vi.mock('@/hooks/use-clipboard', () => ({
|
||||
useClipboard: () => ({
|
||||
copy: vi.fn(),
|
||||
copied: false,
|
||||
reset: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user