diff --git a/api/.env.example b/api/.env.example index c6541731e6..2c1a755059 100644 --- a/api/.env.example +++ b/api/.env.example @@ -71,6 +71,13 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= +REDIS_RETRY_RETRIES=3 +REDIS_RETRY_BACKOFF_BASE=1.0 +REDIS_RETRY_BACKOFF_CAP=10.0 +REDIS_SOCKET_TIMEOUT=5.0 +REDIS_SOCKET_CONNECT_TIMEOUT=5.0 +REDIS_HEALTH_CHECK_INTERVAL=30 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BACKEND=redis diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 3b91207545..b49275758a 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -117,6 +117,37 @@ class RedisConfig(BaseSettings): default=None, ) + REDIS_RETRY_RETRIES: NonNegativeInt = Field( + description="Maximum number of retries per Redis command on " + "transient failures (ConnectionError, TimeoutError, socket.timeout)", + default=3, + ) + + REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field( + description="Base delay in seconds for exponential backoff between retries", + default=1.0, + ) + + REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field( + description="Maximum backoff delay in seconds between retries", + default=10.0, + ) + + REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis read/write operations", + default=5.0, + ) + + REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis connection establishment", + default=5.0, + ) + + REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field( + description="Interval in seconds between Redis connection health checks (0 to disable)", + default=30, + ) + @field_validator("REDIS_MAX_CONNECTIONS", mode="before") @classmethod def _empty_string_to_none_for_max_conns(cls, v): diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index f6d076320c..9771d6f1e5 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -384,24 +384,27 @@ class VariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 93feec0019..3549f9542d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -223,24 +223,27 @@ class RagPipelineVariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index e37e78c966..5d79e1b5e9 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -168,12 +168,13 @@ class ConsoleWorkflowEventsApi(Resource): else: msg_generator = MessageGenerator() generator: BaseAppGenerator - if app.mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app.mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + match app.mode: + case AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + case AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + case _: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py index 61568e70e6..474f9c0957 100644 --- a/api/controllers/web/workflow_events.py +++ b/api/controllers/web/workflow_events.py @@ -72,12 +72,13 @@ class WorkflowEventsApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) msg_generator = MessageGenerator() generator: BaseAppGenerator - if app_mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app_mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + match app_mode: + case AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + case AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + case _: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index a454217768..0bb10190c4 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -57,36 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: - from services.credit_pool_service import CreditPoolService + match system_configuration.current_quota_type: + case ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - ) - elif system_configuration.current_quota_type == ProviderQuotaType.PAID: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - pool_type="paid", - ) - else: - with sessionmaker(bind=db.engine).begin() as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, ) - session.execute(stmt) + case ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + case ProviderQuotaType.FREE: + with sessionmaker(bind=db.engine).begin() as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. + Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=naive_utc_now(), + ) + ) + session.execute(stmt) diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index b23a33923b..77310baf74 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -40,41 +40,44 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl size = 0 extension = "" - if message_file.transfer_method == FileTransferMethod.REMOTE_URL: - url = message_file.url - if message_file.url: - filename = message_file.url.split("/")[-1].split("?")[0] - if "." in filename: - extension = "." + filename.rsplit(".", 1)[1] - elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: - if upload_file: - url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) - filename = upload_file.name - mime_type = upload_file.mime_type or "application/octet-stream" - size = upload_file.size or 0 - extension = f".{upload_file.extension}" if upload_file.extension else "" - elif message_file.upload_file_id: - url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: - if message_file.url.startswith(("http://", "https://")): + match message_file.transfer_method: + case FileTransferMethod.REMOTE_URL: url = message_file.url - filename = message_file.url.split("/")[-1].split("?")[0] - if "." in filename: - extension = "." + filename.rsplit(".", 1)[1] - else: - url_parts = message_file.url.split("/") - if url_parts: - file_part = url_parts[-1].split("?")[0] - if "." in file_part: - tool_file_id, ext = file_part.rsplit(".", 1) - extension = f".{ext}" - if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + if message_file.url: + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + case FileTransferMethod.LOCAL_FILE: + if upload_file: + url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) + filename = upload_file.name + mime_type = upload_file.mime_type or "application/octet-stream" + size = upload_file.size or 0 + extension = f".{upload_file.extension}" if upload_file.extension else "" + elif message_file.upload_file_id: + url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) + case FileTransferMethod.TOOL_FILE if message_file.url: + if message_file.url.startswith(("http://", "https://")): + url = message_file.url + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + else: + url_parts = message_file.url.split("/") + if url_parts: + file_part = url_parts[-1].split("?")[0] + if "." in file_part: + tool_file_id, ext = file_part.rsplit(".", 1) + extension = f".{ext}" + if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + extension = ".bin" + else: + tool_file_id = file_part extension = ".bin" - else: - tool_file_id = file_part - extension = ".bin" - url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) - filename = file_part + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) + filename = file_part + case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE: + pass transfer_method_value = message_file.transfer_method.value remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 8de002ae55..72171d1536 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -187,15 +187,16 @@ def build_parameter_schema( def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict: """Prepare arguments based on app mode""" - if app.mode == AppMode.WORKFLOW: - return {"inputs": arguments} - elif app.mode == AppMode.COMPLETION: - return {"query": "", "inputs": arguments} - else: - # Chat modes - create a copy to avoid modifying original dict - args_copy = arguments.copy() - query = args_copy.pop("query", "") - return {"query": query, "inputs": args_copy} + match app.mode: + case AppMode.WORKFLOW: + return {"inputs": arguments} + case AppMode.COMPLETION: + return {"query": "", "inputs": arguments} + case _: + # Chat modes - create a copy to avoid modifying original dict + args_copy = arguments.copy() + query = args_copy.pop("query", "") + return {"query": query, "inputs": args_copy} def extract_answer_from_response(app: App, response: Any) -> str: @@ -229,17 +230,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str: def process_mapping_response(app: App, response: Mapping) -> str: """Process mapping response based on app mode""" - if app.mode in { - AppMode.ADVANCED_CHAT, - AppMode.COMPLETION, - AppMode.CHAT, - AppMode.AGENT_CHAT, - }: - return response.get("answer", "") - elif app.mode == AppMode.WORKFLOW: - return json.dumps(response["data"]["outputs"], ensure_ascii=False) - else: - raise ValueError("Invalid app mode: " + str(app.mode)) + match app.mode: + case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT: + return response.get("answer", "") + case AppMode.WORKFLOW: + return json.dumps(response["data"]["outputs"], ensure_ascii=False) + case _: + raise ValueError("Invalid app mode: " + str(app.mode)) def convert_input_form_to_parameters( diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index be11d2223c..e2d2be92cb 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -72,17 +72,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}: - if not query: - raise ValueError("missing query") + match app.mode: + case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT: + if not query: + raise ValueError("missing query") - return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) - elif app.mode == AppMode.WORKFLOW: - return cls.invoke_workflow_app(app, user, stream, inputs, files) - elif app.mode == AppMode.COMPLETION: - return cls.invoke_completion_app(app, user, stream, inputs, files) - - raise ValueError("unexpected app type") + return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) + case AppMode.WORKFLOW: + return cls.invoke_workflow_app(app, user, stream, inputs, files) + case AppMode.COMPLETION: + return cls.invoke_completion_app(app, user, stream, inputs, files) + case _: + raise ValueError("unexpected app type") @classmethod def invoke_chat_app( @@ -98,60 +99,61 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke chat app """ - if app.mode == AppMode.ADVANCED_CHAT: - workflow = app.workflow - if not workflow: + match app.mode: + case AppMode.ADVANCED_CHAT: + workflow = app.workflow + if not workflow: + raise ValueError("unexpected app type") + + pause_config = PauseStateLayerConfig( + session_factory=db.engine, + state_owner_user_id=workflow.created_by, + ) + + return AdvancedChatAppGenerator().generate( + app_model=app, + workflow=workflow, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + workflow_run_id=str(uuid.uuid4()), + streaming=stream, + pause_state_config=pause_config, + ) + case AppMode.AGENT_CHAT: + return AgentChatAppGenerator().generate( + app_model=app, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + streaming=stream, + ) + case AppMode.CHAT: + return ChatAppGenerator().generate( + app_model=app, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + streaming=stream, + ) + case _: raise ValueError("unexpected app type") - pause_config = PauseStateLayerConfig( - session_factory=db.engine, - state_owner_user_id=workflow.created_by, - ) - - return AdvancedChatAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - workflow_run_id=str(uuid.uuid4()), - streaming=stream, - pause_state_config=pause_config, - ) - elif app.mode == AppMode.AGENT_CHAT: - return AgentChatAppGenerator().generate( - app_model=app, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - streaming=stream, - ) - elif app.mode == AppMode.CHAT: - return ChatAppGenerator().generate( - app_model=app, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - streaming=stream, - ) - else: - raise ValueError("unexpected app type") - @classmethod def invoke_workflow_app( cls, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 552de66f8b..e3b3f83c20 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -961,36 +961,37 @@ class ProviderManager: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") - if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=trail_pool.quota_used, - quota_limit=trail_pool.quota_limit, - is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + match provider_quota.quota_type: + case ProviderQuotaType.TRIAL if trail_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=trail_pool.quota_used, + quota_limit=trail_pool.quota_limit, + is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=paid_pool.quota_used, - quota_limit=paid_pool.quota_limit, - is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + case ProviderQuotaType.PAID if paid_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=paid_pool.quota_used, + quota_limit=paid_pool.quota_limit, + is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - else: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=provider_record.quota_used, - quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used - or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + case _: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=provider_record.quota_used, + quota_limit=provider_record.quota_limit, + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) quota_configurations.append(quota_configuration) diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c72bdf02ed..03e3c5918d 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -7,14 +7,13 @@ from sqlalchemy import select from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelManager -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.entities import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index a346eb53c4..6a189fa6aa 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,11 +1,10 @@ -from typing import NotRequired, TypedDict, cast +from typing import cast from pydantic import BaseModel, Field from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.entities import DocumentContext, RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument @@ -17,18 +16,6 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService - -class DefaultRetrievalModelDict(TypedDict): - search_method: RetrievalMethod - reranking_enable: bool - reranking_model: RerankingModelDict - reranking_mode: NotRequired[str] - weights: NotRequired[WeightsDict | None] - score_threshold: NotRequired[float] - top_k: int - score_threshold_enabled: bool - - default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index ebaac93934..6a0d633627 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -155,24 +155,25 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = raw_data continue - if param_type == SegmentType.FILE: - # Get File object (already processed by webhook controller) - files = webhook_data.get("files", {}) - if files and isinstance(files, dict): - file = files.get(param_name) - if file and isinstance(file, dict): - file_var = self.generate_file_var(param_name, file) - if file_var: - outputs[param_name] = file_var + match param_type: + case SegmentType.FILE: + # Get File object (already processed by webhook controller) + files = webhook_data.get("files", {}) + if files and isinstance(files, dict): + file = files.get(param_name) + if file and isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files else: outputs[param_name] = files else: outputs[param_name] = files - else: - outputs[param_name] = files - else: - # Get regular body parameter - outputs[param_name] = webhook_data.get("body", {}).get(param_name) + case _: + # Get regular body parameter + outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py index ffd9a7e2b5..9cda0bf90a 100644 --- a/api/enterprise/telemetry/metric_handler.py +++ b/api/enterprise/telemetry/metric_handler.py @@ -68,46 +68,49 @@ class EnterpriseMetricHandler: # Route to appropriate handler based on case case = envelope.case - if case == TelemetryCase.APP_CREATED: - self._on_app_created(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) - elif case == TelemetryCase.APP_UPDATED: - self._on_app_updated(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) - elif case == TelemetryCase.APP_DELETED: - self._on_app_deleted(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) - elif case == TelemetryCase.FEEDBACK_CREATED: - self._on_feedback_created(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) - elif case == TelemetryCase.MESSAGE_RUN: - self._on_message_run(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) - elif case == TelemetryCase.TOOL_EXECUTION: - self._on_tool_execution(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) - elif case == TelemetryCase.MODERATION_CHECK: - self._on_moderation_check(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) - elif case == TelemetryCase.SUGGESTED_QUESTION: - self._on_suggested_question(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) - elif case == TelemetryCase.DATASET_RETRIEVAL: - self._on_dataset_retrieval(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) - elif case == TelemetryCase.GENERATE_NAME: - self._on_generate_name(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) - elif case == TelemetryCase.PROMPT_GENERATION: - self._on_prompt_generation(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) - else: - logger.warning( - "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", - case, - envelope.tenant_id, - envelope.event_id, - ) + match case: + case TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + case TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + case TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + case TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + case TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + case TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + case TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + case TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + case TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + case TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + case TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + case TelemetryCase.WORKFLOW_RUN | TelemetryCase.NODE_EXECUTION | TelemetryCase.DRAFT_NODE_EXECUTION: + pass + case _: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: """Check if this event has already been processed. diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 5f528dbf9e..b9e592cadb 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union import redis from redis import RedisError +from redis.backoff import ExponentialWithJitterBackoff # type: ignore from redis.cache import CacheConfig from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection +from redis.retry import Retry from redis.sentinel import Sentinel from configs import dify_config @@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None: return CacheConfig() +def _get_retry_policy() -> Retry: + """Build the shared retry policy for Redis connections.""" + return Retry( + backoff=ExponentialWithJitterBackoff( + base=dify_config.REDIS_RETRY_BACKOFF_BASE, + cap=dify_config.REDIS_RETRY_BACKOFF_CAP, + ), + retries=dify_config.REDIS_RETRY_RETRIES, + ) + + +def _get_connection_health_params() -> dict[str, Any]: + """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" + return { + "retry": _get_retry_policy(), + "socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT, + "socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, + "health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL, + } + + +def _get_cluster_connection_health_params() -> dict[str, Any]: + """Get retry and timeout parameters for Redis Cluster clients. + + RedisCluster does not support ``health_check_interval`` as a constructor + keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded + here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` + are passed through. + """ + params = _get_connection_health_params() + return {k: v for k, v in params.items() if k != "health_check_interval"} + + def _get_base_redis_params() -> dict[str, Any]: - """Get base Redis connection parameters.""" + """Get base Redis connection parameters including retry and health policy.""" return { "username": dify_config.REDIS_USERNAME, "password": dify_config.REDIS_PASSWORD or None, @@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]: "decode_responses": False, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), + **_get_connection_health_params(), } @@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: "password": dify_config.REDIS_CLUSTERS_PASSWORD, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), + **_get_cluster_connection_health_params(), } if dify_config.REDIS_MAX_CONNECTIONS: cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS @@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() - redis_params.update( + params = {**redis_params} + params.update( { "host": dify_config.REDIS_HOST, "port": dify_config.REDIS_PORT, @@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis ) if dify_config.REDIS_MAX_CONNECTIONS: - redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS if ssl_kwargs: - redis_params.update(ssl_kwargs) + params.update(ssl_kwargs) - pool = redis.ConnectionPool(**redis_params) + pool = redis.ConnectionPool(**params) client: redis.Redis = redis.Redis(connection_pool=pool) return client def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: max_conns = dify_config.REDIS_MAX_CONNECTIONS - if use_clusters: - if max_conns: - return RedisCluster.from_url(pubsub_url, max_connections=max_conns) - else: - return RedisCluster.from_url(pubsub_url) + if use_clusters: + health_params = _get_cluster_connection_health_params() + kwargs: dict[str, Any] = {**health_params} + if max_conns: + kwargs["max_connections"] = max_conns + return RedisCluster.from_url(pubsub_url, **kwargs) + + health_params = _get_connection_health_params() + kwargs = {**health_params} if max_conns: - return redis.Redis.from_url(pubsub_url, max_connections=max_conns) - else: - return redis.Redis.from_url(pubsub_url) + kwargs["max_connections"] = max_conns + return redis.Redis.from_url(pubsub_url, **kwargs) def init_app(app: DifyApp): diff --git a/api/pyproject.toml b/api/pyproject.toml index dab420fc87..2f6581a199 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "numpy~=1.26.4", "openpyxl~=3.1.5", "opik~=1.10.37", - "litellm==1.82.6", # Pinned to avoid madoka dependency issue + "litellm==1.83.0", # Pinned to avoid madoka dependency issue "opentelemetry-api==1.40.0", "opentelemetry-distro==0.61b0", "opentelemetry-exporter-otlp==1.40.0", diff --git a/api/tests/unit_tests/extensions/test_redis.py b/api/tests/unit_tests/extensions/test_redis.py index 933fa32894..5e9be4ab9b 100644 --- a/api/tests/unit_tests/extensions/test_redis.py +++ b/api/tests/unit_tests/extensions/test_redis.py @@ -1,53 +1,125 @@ +from unittest.mock import patch + from redis import RedisError +from redis.retry import Retry -from extensions.ext_redis import redis_fallback +from extensions.ext_redis import ( + _get_base_redis_params, + _get_cluster_connection_health_params, + _get_connection_health_params, + redis_fallback, +) -def test_redis_fallback_success(): - @redis_fallback(default_return=None) - def test_func(): - return "success" +class TestGetConnectionHealthParams: + @patch("extensions.ext_redis.dify_config") + def test_includes_all_health_params(self, mock_config): + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() == "success" + params = _get_connection_health_params() + + assert "retry" in params + assert "socket_timeout" in params + assert "socket_connect_timeout" in params + assert "health_check_interval" in params + assert isinstance(params["retry"], Retry) + assert params["retry"]._retries == 3 + assert params["socket_timeout"] == 5.0 + assert params["socket_connect_timeout"] == 5.0 + assert params["health_check_interval"] == 30 -def test_redis_fallback_error(): - @redis_fallback(default_return="fallback") - def test_func(): - raise RedisError("Redis error") +class TestGetClusterConnectionHealthParams: + @patch("extensions.ext_redis.dify_config") + def test_excludes_health_check_interval(self, mock_config): + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() == "fallback" + params = _get_cluster_connection_health_params() + + assert "retry" in params + assert "socket_timeout" in params + assert "socket_connect_timeout" in params + assert "health_check_interval" not in params -def test_redis_fallback_none_default(): - @redis_fallback() - def test_func(): - raise RedisError("Redis error") +class TestGetBaseRedisParams: + @patch("extensions.ext_redis.dify_config") + def test_includes_retry_and_health_params(self, mock_config): + mock_config.REDIS_USERNAME = None + mock_config.REDIS_PASSWORD = None + mock_config.REDIS_DB = 0 + mock_config.REDIS_SERIALIZATION_PROTOCOL = 3 + mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() is None + params = _get_base_redis_params() + + assert "retry" in params + assert isinstance(params["retry"], Retry) + assert params["socket_timeout"] == 5.0 + assert params["socket_connect_timeout"] == 5.0 + assert params["health_check_interval"] == 30 + # Existing params still present + assert params["db"] == 0 + assert params["encoding"] == "utf-8" -def test_redis_fallback_with_args(): - @redis_fallback(default_return=0) - def test_func(x, y): - raise RedisError("Redis error") +class TestRedisFallback: + def test_redis_fallback_success(self): + @redis_fallback(default_return=None) + def test_func(): + return "success" - assert test_func(1, 2) == 0 + assert test_func() == "success" + def test_redis_fallback_error(self): + @redis_fallback(default_return="fallback") + def test_func(): + raise RedisError("Redis error") -def test_redis_fallback_with_kwargs(): - @redis_fallback(default_return={}) - def test_func(x=None, y=None): - raise RedisError("Redis error") + assert test_func() == "fallback" - assert test_func(x=1, y=2) == {} + def test_redis_fallback_none_default(self): + @redis_fallback() + def test_func(): + raise RedisError("Redis error") + assert test_func() is None -def test_redis_fallback_preserves_function_metadata(): - @redis_fallback(default_return=None) - def test_func(): - """Test function docstring""" - pass + def test_redis_fallback_with_args(self): + @redis_fallback(default_return=0) + def test_func(x, y): + raise RedisError("Redis error") - assert test_func.__name__ == "test_func" - assert test_func.__doc__ == "Test function docstring" + assert test_func(1, 2) == 0 + + def test_redis_fallback_with_kwargs(self): + @redis_fallback(default_return={}) + def test_func(x=None, y=None): + raise RedisError("Redis error") + + assert test_func(x=1, y=2) == {} + + def test_redis_fallback_preserves_function_metadata(self): + @redis_fallback(default_return=None) + def test_func(): + """Test function docstring""" + pass + + assert test_func.__name__ == "test_func" + assert test_func.__doc__ == "Test function docstring" diff --git a/api/uv.lock b/api/uv.lock index 5015f76224..b1145eac56 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1507,7 +1507,7 @@ requires-dist = [ { name = "json-repair", specifier = ">=0.55.1" }, { name = "langfuse", specifier = ">=3.0.0,<5.0.0" }, { name = "langsmith", specifier = "~=0.7.16" }, - { name = "litellm", specifier = "==1.82.6" }, + { name = "litellm", specifier = "==1.83.0" }, { name = "markdown", specifier = "~=3.10.2" }, { name = "mlflow-skinny", specifier = ">=3.0.0" }, { name = "numpy", specifier = "~=1.26.4" }, @@ -3121,7 +3121,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.82.6" +version = "1.83.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3137,9 +3137,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" } +sdist = { url = "https://files.pythonhosted.org/packages/22/92/6ce9737554994ca8e536e5f4f6a87cc7c4774b656c9eb9add071caf7d54b/litellm-1.83.0.tar.gz", hash = "sha256:860bebc76c4bb27b4cf90b4a77acd66dba25aced37e3db98750de8a1766bfb7a", size = 17333062, upload-time = "2026-03-31T05:08:25.331Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" }, + { url = "https://files.pythonhosted.org/packages/19/2c/a670cc050fcd6f45c6199eb99e259c73aea92edba8d5c2fc1b3686d36217/litellm-1.83.0-py3-none-any.whl", hash = "sha256:88c536d339248f3987571493015784671ba3f193a328e1ea6780dbebaa2094a8", size = 15610306, upload-time = "2026-03-31T05:08:21.987Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index c046f6d378..f6da6c568d 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -373,6 +373,20 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= +# Redis connection and retry configuration +# max redis retry +REDIS_RETRY_RETRIES=3 +# Base delay (in seconds) for exponential backoff on retries +REDIS_RETRY_BACKOFF_BASE=1.0 +# Cap (in seconds) for exponential backoff on retries +REDIS_RETRY_BACKOFF_CAP=10.0 +# Timeout (in seconds) for Redis socket operations +REDIS_SOCKET_TIMEOUT=5.0 +# Timeout (in seconds) for establishing a Redis connection +REDIS_SOCKET_CONNECT_TIMEOUT=5.0 +# Interval (in seconds) for Redis health checks +REDIS_HEALTH_CHECK_INTERVAL=30 + # ------------------------------ # Celery Configuration # ------------------------------ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3f6a13e78e..dbadc58f89 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -100,6 +100,12 @@ x-shared-env: &shared-api-worker-env REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} + REDIS_RETRY_RETRIES: ${REDIS_RETRY_RETRIES:-3} + REDIS_RETRY_BACKOFF_BASE: ${REDIS_RETRY_BACKOFF_BASE:-1.0} + REDIS_RETRY_BACKOFF_CAP: ${REDIS_RETRY_BACKOFF_CAP:-10.0} + REDIS_SOCKET_TIMEOUT: ${REDIS_SOCKET_TIMEOUT:-5.0} + REDIS_SOCKET_CONNECT_TIMEOUT: ${REDIS_SOCKET_CONNECT_TIMEOUT:-5.0} + REDIS_HEALTH_CHECK_INTERVAL: ${REDIS_HEALTH_CHECK_INTERVAL:-30} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} CELERY_BACKEND: ${CELERY_BACKEND:-redis} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} diff --git a/e2e/features/apps/create-chatbot-app.feature b/e2e/features/apps/create-chatbot-app.feature new file mode 100644 index 0000000000..4f506e4f40 --- /dev/null +++ b/e2e/features/apps/create-chatbot-app.feature @@ -0,0 +1,11 @@ +@apps @authenticated +Feature: Create Chatbot app + Scenario: Create a new Chatbot app and redirect to the configuration page + Given I am signed in as the default E2E admin + When I open the apps console + And I start creating a blank app + And I expand the beginner app types + And I select the "Chatbot" app type + And I enter a unique E2E app name + And I confirm app creation + Then I should land on the app configuration page diff --git a/e2e/features/apps/create-workflow-app.feature b/e2e/features/apps/create-workflow-app.feature new file mode 100644 index 0000000000..b88d94d899 --- /dev/null +++ b/e2e/features/apps/create-workflow-app.feature @@ -0,0 +1,10 @@ +@apps @authenticated +Feature: Create Workflow app + Scenario: Create a new Workflow app and redirect to the workflow editor + Given I am signed in as the default E2E admin + When I open the apps console + And I start creating a blank app + And I select the "Workflow" app type + And I enter a unique E2E app name + And I confirm app creation + Then I should land on the workflow editor diff --git a/e2e/features/auth/sign-out.feature b/e2e/features/auth/sign-out.feature new file mode 100644 index 0000000000..0f377ea133 --- /dev/null +++ b/e2e/features/auth/sign-out.feature @@ -0,0 +1,8 @@ +@auth @authenticated +Feature: Sign out + Scenario: Sign out from the apps console + Given I am signed in as the default E2E admin + When I open the apps console + And I open the account menu + And I sign out + Then I should be on the sign-in page diff --git a/e2e/features/step-definitions/apps/create-app.steps.ts b/e2e/features/step-definitions/apps/create-app.steps.ts index b8e76c6f06..6bc9ae30b6 100644 --- a/e2e/features/step-definitions/apps/create-app.steps.ts +++ b/e2e/features/step-definitions/apps/create-app.steps.ts @@ -24,6 +24,30 @@ When('I confirm app creation', async function (this: DifyWorld) { await createButton.click() }) +When('I select the {string} app type', async function (this: DifyWorld, appType: string) { + const dialog = this.getPage().getByRole('dialog') + const appTypeTitle = dialog.getByText(appType, { exact: true }) + + await expect(appTypeTitle).toBeVisible() + await appTypeTitle.click() +}) + +When('I expand the beginner app types', async function (this: DifyWorld) { + const page = this.getPage() + const toggle = page.getByRole('button', { name: 'More basic app types' }) + + await expect(toggle).toBeVisible() + await toggle.click() +}) + Then('I should land on the app editor', async function (this: DifyWorld) { await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/(workflow|configuration)(?:\?.*)?$/) }) + +Then('I should land on the workflow editor', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/workflow(?:\?.*)?$/) +}) + +Then('I should land on the app configuration page', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/configuration(?:\?.*)?$/) +}) diff --git a/e2e/features/step-definitions/auth/sign-out.steps.ts b/e2e/features/step-definitions/auth/sign-out.steps.ts new file mode 100644 index 0000000000..935b73c3af --- /dev/null +++ b/e2e/features/step-definitions/auth/sign-out.steps.ts @@ -0,0 +1,25 @@ +import { Then, When } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +When('I open the account menu', async function (this: DifyWorld) { + const page = this.getPage() + const trigger = page.getByRole('button', { name: 'Account' }) + + await expect(trigger).toBeVisible() + await trigger.click() +}) + +When('I sign out', async function (this: DifyWorld) { + const page = this.getPage() + + await expect(page.getByText('Log out')).toBeVisible() + await page.getByText('Log out').click() +}) + +Then('I should be on the sign-in page', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/signin/) + await expect(this.getPage().getByRole('button', { name: /^Sign in$/i })).toBeVisible({ + timeout: 30_000, + }) +}) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8e8c0970a6..ee3794d88d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -249,6 +249,9 @@ catalogs: class-variance-authority: specifier: 0.7.1 version: 0.7.1 + client-only: + specifier: 0.0.1 + version: 0.0.1 clsx: specifier: 2.1.1 version: 2.1.1 @@ -324,9 +327,6 @@ catalogs: fast-deep-equal: specifier: 3.1.3 version: 3.1.3 - foxact: - specifier: 0.3.0 - version: 0.3.0 happy-dom: specifier: 20.8.9 version: 20.8.9 @@ -736,6 +736,9 @@ importers: class-variance-authority: specifier: 'catalog:' version: 0.7.1 + client-only: + specifier: 'catalog:' + version: 0.0.1 clsx: specifier: 'catalog:' version: 2.1.1 @@ -781,9 +784,6 @@ importers: fast-deep-equal: specifier: 'catalog:' version: 3.1.3 - foxact: - specifier: 'catalog:' - version: 0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) hast-util-to-jsx-runtime: specifier: 'catalog:' version: 2.3.6 @@ -5871,9 +5871,6 @@ packages: resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} engines: {node: '>=0.10.0'} - event-target-bus@1.0.0: - resolution: {integrity: sha512-uPcWKbj/BJU3Tbw9XqhHqET4/LBOhvv3/SJWr7NksxA6TC5YqBpaZgawE9R+WpYFCBFSAE4Vun+xQS6w4ABdlA==} - events@3.3.0: resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==} engines: {node: '>=0.8.x'} @@ -5986,17 +5983,6 @@ packages: engines: {node: '>=18.3.0'} hasBin: true - foxact@0.3.0: - resolution: {integrity: sha512-CSlMlC0KlKQQEO83iLeQCLuT1V0OqnMWj7mjLstIDV8baMe1w4F7z3cz3/T+6Z8W12jqkQj07rwlw4Gi39knGg==} - peerDependencies: - react: '*' - react-dom: '*' - peerDependenciesMeta: - react: - optional: true - react-dom: - optional: true - fs-constants@1.0.0: resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==} @@ -7710,9 +7696,6 @@ packages: resolution: {integrity: sha512-OwrZRZAfhHww0WEnKHDY8OM0U/Qs8OTfIDWhUD4BLpNJUfXK4cGmjiagGze086m+mhI+V2nD0gfbHEnJjb9STA==} engines: {node: '>=10'} - server-only@0.0.1: - resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==} - sharp@0.34.5: resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} @@ -13552,8 +13535,6 @@ snapshots: esutils@2.0.3: {} - event-target-bus@1.0.0: {} - events@3.3.0: {} expand-template@2.0.3: @@ -13661,15 +13642,6 @@ snapshots: dependencies: fd-package-json: 2.0.0 - foxact@0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4): - dependencies: - client-only: 0.0.1 - event-target-bus: 1.0.0 - server-only: 0.0.1 - optionalDependencies: - react: 19.2.4 - react-dom: 19.2.4(react@19.2.4) - fs-constants@1.0.0: optional: true @@ -15905,8 +15877,6 @@ snapshots: seroval@1.5.1: {} - server-only@0.0.1: {} - sharp@0.34.5: dependencies: '@img/colour': 1.1.0 diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 6fe023066a..b7918fff1b 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -129,6 +129,7 @@ catalog: ahooks: 3.9.7 autoprefixer: 10.4.27 class-variance-authority: 0.7.1 + client-only: 0.0.1 clsx: 2.1.1 cmdk: 1.1.1 code-inspector-plugin: 1.5.1 @@ -154,7 +155,6 @@ catalog: eslint-plugin-sonarjs: 4.0.2 eslint-plugin-storybook: 10.3.5 fast-deep-equal: 3.1.3 - foxact: 0.3.0 happy-dom: 20.8.9 hast-util-to-jsx-runtime: 2.3.6 hono: 4.12.12 diff --git a/web/app/components/base/copy-feedback/__tests__/index.spec.tsx b/web/app/components/base/copy-feedback/__tests__/index.spec.tsx index 322a9970af..8cc22693b6 100644 --- a/web/app/components/base/copy-feedback/__tests__/index.spec.tsx +++ b/web/app/components/base/copy-feedback/__tests__/index.spec.tsx @@ -5,7 +5,7 @@ const mockCopy = vi.fn() const mockReset = vi.fn() let mockCopied = false -vi.mock('foxact/use-clipboard', () => ({ +vi.mock('@/hooks/use-clipboard', () => ({ useClipboard: () => ({ copy: mockCopy, reset: mockReset, diff --git a/web/app/components/base/copy-feedback/index.tsx b/web/app/components/base/copy-feedback/index.tsx index 80b35eb3a8..5210066670 100644 --- a/web/app/components/base/copy-feedback/index.tsx +++ b/web/app/components/base/copy-feedback/index.tsx @@ -3,11 +3,11 @@ import { RiClipboardFill, RiClipboardLine, } from '@remixicon/react' -import { useClipboard } from 'foxact/use-clipboard' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import Tooltip from '@/app/components/base/tooltip' +import { useClipboard } from '@/hooks/use-clipboard' import copyStyle from './style.module.css' type Props = { diff --git a/web/app/components/base/copy-icon/__tests__/index.spec.tsx b/web/app/components/base/copy-icon/__tests__/index.spec.tsx index 3db76ef606..1ce9e6dbf5 100644 --- a/web/app/components/base/copy-icon/__tests__/index.spec.tsx +++ b/web/app/components/base/copy-icon/__tests__/index.spec.tsx @@ -5,7 +5,7 @@ const copy = vi.fn() const reset = vi.fn() let copied = false -vi.mock('foxact/use-clipboard', () => ({ +vi.mock('@/hooks/use-clipboard', () => ({ useClipboard: () => ({ copy, reset, diff --git a/web/app/components/base/copy-icon/index.tsx b/web/app/components/base/copy-icon/index.tsx index 78c0fcb8c3..15332592d0 100644 --- a/web/app/components/base/copy-icon/index.tsx +++ b/web/app/components/base/copy-icon/index.tsx @@ -1,7 +1,7 @@ 'use client' -import { useClipboard } from 'foxact/use-clipboard' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import { useClipboard } from '@/hooks/use-clipboard' import Tooltip from '../tooltip' type Props = { diff --git a/web/app/components/base/input-with-copy/__tests__/index.spec.tsx b/web/app/components/base/input-with-copy/__tests__/index.spec.tsx index 201c419444..33ebec5cbc 100644 --- a/web/app/components/base/input-with-copy/__tests__/index.spec.tsx +++ b/web/app/components/base/input-with-copy/__tests__/index.spec.tsx @@ -6,7 +6,7 @@ const mockCopy = vi.fn() let mockCopied = false const mockReset = vi.fn() -vi.mock('foxact/use-clipboard', () => ({ +vi.mock('@/hooks/use-clipboard', () => ({ useClipboard: () => ({ copy: mockCopy, copied: mockCopied, diff --git a/web/app/components/base/input-with-copy/index.tsx b/web/app/components/base/input-with-copy/index.tsx index e85a7bd6f4..33db47baaa 100644 --- a/web/app/components/base/input-with-copy/index.tsx +++ b/web/app/components/base/input-with-copy/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { InputProps } from '../input' -import { useClipboard } from 'foxact/use-clipboard' import * as React from 'react' import { useTranslation } from 'react-i18next' +import { useClipboard } from '@/hooks/use-clipboard' import { cn } from '@/utils/classnames' import ActionButton from '../action-button' import Tooltip from '../tooltip' diff --git a/web/hooks/noop.ts b/web/hooks/noop.ts new file mode 100644 index 0000000000..9cf6f968dc --- /dev/null +++ b/web/hooks/noop.ts @@ -0,0 +1,7 @@ +type Noop = { + // eslint-disable-next-line ts/no-explicit-any + (...args: any[]): any +} + +/** @see https://foxact.skk.moe/noop */ +export const noop: Noop = () => { /* noop */ } diff --git a/web/hooks/use-clipboard.ts b/web/hooks/use-clipboard.ts new file mode 100644 index 0000000000..6d24c04027 --- /dev/null +++ b/web/hooks/use-clipboard.ts @@ -0,0 +1,72 @@ +import { useRef, useState } from 'react' +import { writeTextToClipboard } from '@/utils/clipboard' +import { noop } from './noop' +import { useStableHandler } from './use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired' +import { useCallback } from './use-typescript-happy-callback' +import 'client-only' + +type UseClipboardOption = { + timeout?: number + usePromptAsFallback?: boolean + promptFallbackText?: string + onCopyError?: (error: Error) => void +} + +/** @see https://foxact.skk.moe/use-clipboard */ +export function useClipboard({ + timeout = 1000, + usePromptAsFallback = false, + promptFallbackText = 'Failed to copy to clipboard automatically, please manually copy the text below.', + onCopyError, +}: UseClipboardOption = {}) { + const [error, setError] = useState(null) + const [copied, setCopied] = useState(false) + const copyTimeoutRef = useRef(null) + + const stablizedOnCopyError = useStableHandler<[e: Error], void>(onCopyError || noop) + + const handleCopyResult = useCallback((isCopied: boolean) => { + if (copyTimeoutRef.current) { + clearTimeout(copyTimeoutRef.current) + } + if (isCopied) { + copyTimeoutRef.current = window.setTimeout(() => setCopied(false), timeout) + } + setCopied(isCopied) + }, [timeout]) + + const handleCopyError = useCallback((e: Error) => { + setError(e) + stablizedOnCopyError(e) + }, [stablizedOnCopyError]) + + const copy = useCallback(async (valueToCopy: string) => { + try { + await writeTextToClipboard(valueToCopy) + } + catch (e) { + if (usePromptAsFallback) { + try { + // eslint-disable-next-line no-alert -- prompt as fallback in case of copy error + window.prompt(promptFallbackText, valueToCopy) + } + catch (e2) { + handleCopyError(e2 as Error) + } + } + else { + handleCopyError(e as Error) + } + } + }, [handleCopyResult, promptFallbackText, handleCopyError, usePromptAsFallback]) + + const reset = useCallback(() => { + setCopied(false) + setError(null) + if (copyTimeoutRef.current) { + clearTimeout(copyTimeoutRef.current) + } + }, []) + + return { copy, reset, error, copied } +} diff --git a/web/hooks/use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired.ts b/web/hooks/use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired.ts new file mode 100644 index 0000000000..227f4fd1fb --- /dev/null +++ b/web/hooks/use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired.ts @@ -0,0 +1,44 @@ +import * as reactExports from 'react' +import { useCallback, useEffect, useLayoutEffect, useRef } from 'react' + +// useIsomorphicInsertionEffect +const useInsertionEffect + = typeof window === 'undefined' + // useInsertionEffect is only available in React 18+ + + ? useEffect + : reactExports.useInsertionEffect || useLayoutEffect + +/** + * @see https://foxact.skk.moe/use-stable-handler-only-when-you-know-what-you-are-doing-or-you-will-be-fired + * Similar to useCallback, with a few subtle differences: + * - The returned function is a stable reference, and will always be the same between renders + * - No dependency lists required + * - Properties or state accessed within the callback will always be "current" + */ +// eslint-disable-next-line ts/no-explicit-any +export function useStableHandler( + callback: (...args: Args) => Result, +): typeof callback { + // Keep track of the latest callback: + // eslint-disable-next-line ts/no-explicit-any + const latestRef = useRef(shouldNotBeInvokedBeforeMount as any) + useInsertionEffect(() => { + latestRef.current = callback + }, [callback]) + + return useCallback((...args) => { + const fn = latestRef.current + return fn(...args) + }, []) +} + +/** + * Render methods should be pure, especially when concurrency is used, + * so we will throw this error if the callback is called while rendering. + */ +function shouldNotBeInvokedBeforeMount() { + throw new Error( + 'foxact: the stablized handler cannot be invoked before the component has mounted.', + ) +} diff --git a/web/hooks/use-typescript-happy-callback.ts b/web/hooks/use-typescript-happy-callback.ts new file mode 100644 index 0000000000..db3ba372c0 --- /dev/null +++ b/web/hooks/use-typescript-happy-callback.ts @@ -0,0 +1,10 @@ +import { useCallback as useCallbackFromReact } from 'react' + +/** @see https://foxact.skk.moe/use-typescript-happy-callback */ +const useTypeScriptHappyCallback: ( + fn: (...args: Args) => R, + deps: React.DependencyList, +) => (...args: Args) => R = useCallbackFromReact + +/** @see https://foxact.skk.moe/use-typescript-happy-callback */ +export const useCallback = useTypeScriptHappyCallback diff --git a/web/package.json b/web/package.json index d2a9e88f4a..8bc31dce31 100644 --- a/web/package.json +++ b/web/package.json @@ -85,6 +85,7 @@ "abcjs": "catalog:", "ahooks": "catalog:", "class-variance-authority": "catalog:", + "client-only": "catalog:", "clsx": "catalog:", "cmdk": "catalog:", "copy-to-clipboard": "catalog:", @@ -100,7 +101,6 @@ "emoji-mart": "catalog:", "es-toolkit": "catalog:", "fast-deep-equal": "catalog:", - "foxact": "catalog:", "hast-util-to-jsx-runtime": "catalog:", "html-entities": "catalog:", "html-to-image": "catalog:", diff --git a/web/vitest.setup.ts b/web/vitest.setup.ts index b17a59bab6..b945f675f7 100644 --- a/web/vitest.setup.ts +++ b/web/vitest.setup.ts @@ -83,11 +83,12 @@ afterEach(async () => { }) }) -// mock foxact/use-clipboard - not available in test environment -vi.mock('foxact/use-clipboard', () => ({ +// mock custom clipboard hook - wraps writeTextToClipboard with fallback +vi.mock('@/hooks/use-clipboard', () => ({ useClipboard: () => ({ copy: vi.fn(), copied: false, + reset: vi.fn(), }), }))