From a297b06aacc6c7c6ed176fb29a5db40c8074de9e Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Fri, 6 Feb 2026 14:38:15 +0800 Subject: [PATCH 1/8] fix: fix tool type is miss (#32042) --- api/.importlinter | 2 - api/core/workflow/nodes/agent/agent_node.py | 42 +--- .../workflow/nodes/agent/test_agent_node.py | 197 ------------------ .../config/agent/agent-tools/index.tsx | 1 + .../hooks/use-tool-selector-state.ts | 1 + .../workflow/block-selector/types.ts | 1 + 6 files changed, 5 insertions(+), 239 deletions(-) delete mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py diff --git a/api/.importlinter b/api/.importlinter index fb66df7334..2a6bb66a95 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -102,8 +102,6 @@ forbidden_modules = core.trigger core.variables ignore_imports = - core.workflow.nodes.agent.agent_node -> core.db.session_factory - core.workflow.nodes.agent.agent_node -> models.tools core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index e64a83034c..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,7 +2,7 @@ from __future__ import annotations import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast from packaging.version import Version from pydantic import ValidationError @@ -11,7 +11,6 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter -from core.db.session_factory import session_factory from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -50,12 +49,6 @@ from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy from models import ToolFile from models.model import Conversation -from models.tools import ( - ApiToolProvider, - BuiltinToolProvider, - MCPToolProvider, - WorkflowToolProvider, -) from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exc import ( @@ -266,7 +259,7 @@ class AgentNode(Node[AgentNodeData]): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = self._infer_tool_provider_type(tool, self.tenant_id) + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -755,34 +748,3 @@ class AgentNode(Node[AgentNodeData]): llm_usage=llm_usage, ) ) - - @staticmethod - def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType: - provider_type_str = tool_config.get("type") - if provider_type_str: - return ToolProviderType(provider_type_str) - - provider_id = tool_config.get("provider_name") - if not provider_id: - return ToolProviderType.BUILT_IN - - with session_factory.create_session() as session: - provider_map: dict[ - type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]], - ToolProviderType, - ] = { - 
WorkflowToolProvider: ToolProviderType.WORKFLOW, - MCPToolProvider: ToolProviderType.MCP, - ApiToolProvider: ToolProviderType.API, - BuiltinToolProvider: ToolProviderType.BUILT_IN, - } - - for provider_model, provider_type in provider_map.items(): - stmt = select(provider_model).where( - provider_model.id == provider_id, - provider_model.tenant_id == tenant_id, - ) - if session.scalar(stmt): - return provider_type - - raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.") diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py deleted file mode 100644 index a95892d0b6..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py +++ /dev/null @@ -1,197 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.nodes.agent.agent_node import AgentNode - - -class TestInferToolProviderType: - """Test cases for AgentNode._infer_tool_provider_type method.""" - - def test_infer_type_from_config_workflow(self): - """Test inferring workflow provider type from config.""" - tool_config = { - "type": "workflow", - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.WORKFLOW - - def test_infer_type_from_config_builtin(self): - """Test inferring builtin provider type from config.""" - tool_config = { - "type": "builtin", - "provider_name": "builtin-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - - def test_infer_type_from_config_api(self): - """Test inferring API provider type from config.""" - tool_config = { - "type": "api", - "provider_name": "api-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.API - - def test_infer_type_from_config_mcp(self): - """Test inferring MCP provider type from config.""" - tool_config = { - "type": "mcp", - "provider_name": "mcp-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.MCP - - def test_infer_type_invalid_config_value_raises_error(self): - """Test that invalid type value in config raises ValueError.""" - tool_config = { - "type": "invalid-type", - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - with pytest.raises(ValueError): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - def test_infer_workflow_type_from_database(self): - """Test inferring workflow provider type from database.""" - tool_config = { - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns a result - mock_session.scalar.return_value = True - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.WORKFLOW - # Should only query once (after finding WorkflowToolProvider) - assert mock_session.scalar.call_count == 1 - - def 
test_infer_mcp_type_from_database(self): - """Test inferring MCP provider type from database.""" - tool_config = { - "provider_name": "mcp-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns None - # Second query (MCPToolProvider) returns a result - mock_session.scalar.side_effect = [None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.MCP - assert mock_session.scalar.call_count == 2 - - def test_infer_api_type_from_database(self): - """Test inferring API provider type from database.""" - tool_config = { - "provider_name": "api-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns None - # Second query (MCPToolProvider) returns None - # Third query (ApiToolProvider) returns a result - mock_session.scalar.side_effect = [None, None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.API - assert mock_session.scalar.call_count == 3 - - def test_infer_builtin_type_from_database(self): - """Test inferring builtin provider type from database.""" - tool_config = { - "provider_name": "builtin-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First three queries return None - # Fourth query (BuiltinToolProvider) returns a result - mock_session.scalar.side_effect = [None, None, None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - assert mock_session.scalar.call_count == 4 - - def test_infer_type_default_when_not_found(self): - """Test raising AgentNodeError when provider is not found in database.""" - tool_config = { - "provider_name": "unknown-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # All queries return None - mock_session.scalar.return_value = None - - # Current implementation raises AgentNodeError when provider not found - from core.workflow.nodes.agent.exc import AgentNodeError - - with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - def test_infer_type_default_when_no_provider_name(self): - """Test defaulting to BUILT_IN when provider_name is missing.""" - tool_config = {} - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - - def test_infer_type_database_exception_propagates(self): - """Test that database exception propagates (current implementation doesn't catch it).""" - tool_config = { - "provider_name": "provider-id", - } - tenant_id = "test-tenant" - - with 
patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # Database query raises exception - mock_session.scalar.side_effect = Exception("Database error") - - # Current implementation doesn't catch exceptions, so it propagates - with pytest.raises(Exception, match="Database error"): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 486c0a8ac9..b97aa6e775 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -109,6 +109,7 @@ const AgentTools: FC = () => { tool_parameters: paramsWithDefaultValue, notAuthor: !tool.is_team_authorization, enabled: true, + type: tool.provider_type as CollectionType, } } const handleSelectTool = (tool: ToolDefaultValue) => { diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts index 44d0ff864e..e5edea5679 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts @@ -129,6 +129,7 @@ export const useToolSelectorState = ({ extra: { description: tool.tool_description, }, + type: tool.provider_type, } }, []) diff --git a/web/app/components/workflow/block-selector/types.ts b/web/app/components/workflow/block-selector/types.ts index 07efb0d02f..39e7b033bd 100644 --- a/web/app/components/workflow/block-selector/types.ts +++ b/web/app/components/workflow/block-selector/types.ts @@ -87,6 +87,7 @@ export type ToolValue = { enabled?: boolean extra?: { description?: string } & Record credential_id?: string + type?: string } export type DataSourceItem = { From 4971e11734e668800a7f5c9c1b879fee22061e28 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Fri, 6 Feb 2026 15:12:32 +0800 Subject: [PATCH 2/8] perf: use batch delete method instead of single delete (#32036) Co-authored-by: fatelei Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: FFXN --- api/tasks/batch_clean_document_task.py | 211 ++++++++++++++---- api/tasks/delete_segment_from_index_task.py | 11 +- api/tasks/document_indexing_sync_task.py | 4 +- .../tasks/test_document_indexing_sync_task.py | 15 ++ 4 files changed, 190 insertions(+), 51 deletions(-) diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index d388284980..747106d373 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -14,6 +14,9 @@ from models.model import UploadFile logger = logging.getLogger(__name__) +# Batch size for database operations to keep transactions short +BATCH_SIZE = 1000 + @shared_task(queue="dataset") def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]): @@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if not doc_form: raise ValueError("doc_form is required") - with session_factory.create_session() as session: - try: - dataset = session.query(Dataset).where(Dataset.id == 
dataset_id).first() - - if not dataset: - raise Exception("Document has no dataset") - - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ).delete(synchronize_session=False) + storage_keys_to_delete: list[str] = [] + index_node_ids: list[str] = [] + segment_ids: list[str] = [] + total_image_upload_file_ids: list[str] = [] + try: + # ============ Step 1: Query segment and file data (short read-only transaction) ============ + with session_factory.create_session() as session: + # Get segments info segments = session.scalars( select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) ).all() - # check segment is exist + if segments: index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean( - dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + segment_ids = [segment.id for segment in segments] + # Collect image file IDs from segment content for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() - for image_file in image_files: - try: - if image_file and image_file.key: - storage.delete(image_file.key) - except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - image_file.id, - ) - stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - session.execute(stmt) - session.delete(segment) + total_image_upload_file_ids.extend(image_upload_file_ids) + + # Query storage keys for image files + if total_image_upload_file_ids: + image_files = session.scalars( + select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids)) + ).all() + storage_keys_to_delete.extend([f.key for f in image_files if f and f.key]) + + # Query storage keys for document files if file_ids: files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() - for file in files: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file.id) - stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) - session.execute(stmt) + storage_keys_to_delete.extend([f.key for f in files if f and f.key]) - session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned documents when documents deleted latency: {end_at - start_at}", - fg="green", + # ============ Step 2: Clean vector index (external service, fresh session for dataset) ============ + if index_node_ids: + try: + # Fetch dataset in a fresh session to avoid DetachedInstanceError + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id) + else: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d", + dataset_id, + document_ids, + len(index_node_ids), ) - ) + + # ============ Step 3: Delete metadata binding (separate 
short transaction) ============ + try: + with session_factory.create_session() as session: + deleted_count = ( + session.query(DatasetMetadataBinding) + .where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ) + .delete(synchronize_session=False) + ) + session.commit() + logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id) except Exception: - logger.exception("Cleaned documents when documents deleted failed") + logger.exception( + "Failed to delete metadata bindings for dataset_id: %s, document_ids: %s", + dataset_id, + document_ids, + ) + + # ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============ + if total_image_upload_file_ids: + failed_batches = 0 + total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE + for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE): + batch = total_image_upload_file_ids[i : i + BATCH_SIZE] + try: + with session_factory.create_session() as session: + stmt = delete(UploadFile).where(UploadFile.id.in_(batch)) + session.execute(stmt) + session.commit() + except Exception: + failed_batches += 1 + logger.exception( + "Failed to delete image UploadFile batch %d-%d for dataset_id: %s", + i, + i + len(batch), + dataset_id, + ) + if failed_batches > 0: + logger.warning( + "Image UploadFile deletion: %d/%d batches failed for dataset_id: %s", + failed_batches, + total_batches, + dataset_id, + ) + + # ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============ + if segment_ids: + failed_batches = 0 + total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE + for i in range(0, len(segment_ids), BATCH_SIZE): + batch = segment_ids[i : i + BATCH_SIZE] + try: + with session_factory.create_session() as session: + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch)) + session.execute(segment_delete_stmt) + session.commit() + except Exception: + failed_batches += 1 + logger.exception( + "Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s", + i, + i + len(batch), + dataset_id, + document_ids, + ) + if failed_batches > 0: + logger.warning( + "DocumentSegment deletion: %d/%d batches failed, document_ids: %s", + failed_batches, + total_batches, + document_ids, + ) + + # ============ Step 6: Delete document-associated files (separate short transaction) ============ + if file_ids: + try: + with session_factory.create_session() as session: + stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) + session.execute(stmt) + session.commit() + except Exception: + logger.exception( + "Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s", + dataset_id, + file_ids, + ) + + # ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============ + storage_delete_failures = 0 + for storage_key in storage_keys_to_delete: + try: + storage.delete(storage_key) + except Exception: + storage_delete_failures += 1 + logger.exception("Failed to delete file from storage, key: %s", storage_key) + if storage_delete_failures > 0: + logger.warning( + "Storage file deletion completed with %d failures out of %d total files for dataset_id: %s", + storage_delete_failures, + len(storage_keys_to_delete), + dataset_id, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, " + f"dataset_id: 
{dataset_id}, document_ids: {document_ids}, " + f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, " + f"storage_files: {len(storage_keys_to_delete)}", + fg="green", + ) + ) + except Exception: + logger.exception( + "Batch clean documents failed for dataset_id: %s, document_ids: %s", + dataset_id, + document_ids, + ) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 764c635d83..a6a2dcebc8 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -67,8 +68,14 @@ def delete_segment_from_index_task( if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) - for binding in segment_attachment_bindings: - session.delete(binding) + segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings] + + for i in range(0, len(segment_attachment_bind_ids), 1000): + segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000]) + ) + session.execute(segment_attachment_bind_delete_stmt) + # delete upload file session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) session.commit() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 149185f6e2..8fa5faa796 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -28,7 +28,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start sync document: {document_id}", fg="green")) start_at = time.perf_counter() - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: @@ -68,7 +68,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): document.indexing_status = "error" document.error = "Datasource credential not found. Please reconnect your Notion workspace." 
document.stopped_at = naive_utc_now() - session.commit() return loader = NotionExtractor( @@ -85,7 +84,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): if last_edited_time != page_edited_time: document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - session.commit() # delete all document segment and index try: diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index fa33034f40..24e0bc76cf 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -114,6 +114,21 @@ def mock_db_session(): session = MagicMock() # Ensure tests can observe session.close() via context manager teardown session.close = MagicMock() + session.commit = MagicMock() + + # Mock session.begin() context manager to auto-commit on exit + begin_cm = MagicMock() + begin_cm.__enter__.return_value = session + + def _begin_exit_side_effect(*args, **kwargs): + # session.begin().__exit__() should commit if no exception + if args[0] is None: # No exception + session.commit() + + begin_cm.__exit__.side_effect = _begin_exit_side_effect + session.begin.return_value = begin_cm + + # Mock create_session() context manager cm = MagicMock() cm.__enter__.return_value = session From c2a3f459c73b2afabd6e59e4613a569cbacc0ca0 Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Fri, 6 Feb 2026 15:32:52 +0800 Subject: [PATCH 3/8] fix(api): return proper HTTP 204 status code in DELETE endpoints (#32012) Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/tag/tags.py | 2 +- api/controllers/service_api/dataset/dataset.py | 8 ++++---- api/controllers/service_api/dataset/document.py | 2 +- api/controllers/service_api/dataset/metadata.py | 2 +- api/controllers/service_api/dataset/segment.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index bc0776f658..7511c970a3 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource): TagService.delete_tag(tag_id) - return 204 + return "", 204 @console_ns.route("/tag-bindings/create") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index db5cabe8aa..c06b81b775 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return 204 + return "", 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: @@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource): payload = TagDeletePayload.model_validate(service_api_ns.payload or {}) TagService.delete_tag(payload.tag_id) - return 204 + return "", 204 @service_api_ns.route("/datasets/tags/binding") @@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource): payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, 
"type": "knowledge"}) - return 204 + return "", 204 @service_api_ns.route("/datasets/tags/unbinding") @@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"}) - return 204 + return "", 204 @service_api_ns.route("/datasets//tags") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index a01524f1bc..0aeb4a2d36 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return 204 + return "", 204 diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 692342a38a..52166f7fcc 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) MetadataService.delete_metadata(dataset_id_str, metadata_id_str) - return 204 + return "", 204 @service_api_ns.route("/datasets//metadata/built-in") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 95679e6fcb..4eb4fed29a 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return 204 + return "", 204 @service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__]) @service_api_ns.doc("update_segment") @@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource): except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) - return 204 + return "", 204 @service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__]) @service_api_ns.doc("update_child_chunk") From d5b9a7b2f81b2e701282d22a48709be68df4b793 Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Fri, 6 Feb 2026 16:12:28 +0800 Subject: [PATCH 4/8] test: only remove text coverage in CI (#32043) --- web/package.json | 5 ++--- web/pnpm-lock.yaml | 15 --------------- web/vitest.config.ts | 4 +++- 3 files changed, 5 insertions(+), 19 deletions(-) diff --git a/web/package.json b/web/package.json index bde1135bdb..2e44328f53 100644 --- a/web/package.json +++ b/web/package.json @@ -47,7 +47,7 @@ "i18n:check": "tsx ./scripts/check-i18n.js", "test": "vitest run", "test:coverage": "vitest run --coverage", - "test:ci": "vitest run --coverage --reporter vitest-tiny-reporter --silent=passed-only", + "test:ci": "vitest run --coverage --silent=passed-only", "test:watch": "vitest --watch", "analyze-component": "node ./scripts/analyze-component.js", "refactor-component": "node ./scripts/refactor-component.js", @@ -236,8 +236,7 @@ "vite": "7.3.1", "vite-tsconfig-paths": "6.0.4", "vitest": "4.0.17", - "vitest-canvas-mock": "1.1.3", - "vitest-tiny-reporter": "1.3.1" + "vitest-canvas-mock": "1.1.3" }, "pnpm": { "overrides": { diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 
7024688ade..169428bfbd 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -585,9 +585,6 @@ importers: vitest-canvas-mock: specifier: 1.1.3 version: 1.1.3(vitest@4.0.17) - vitest-tiny-reporter: - specifier: 1.3.1 - version: 1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17) packages: @@ -7294,12 +7291,6 @@ packages: peerDependencies: vitest: ^3.0.0 || ^4.0.0 - vitest-tiny-reporter@1.3.1: - resolution: {integrity: sha512-9WfLruQBbxm4EqMIS0jDZmQjvMgsWgHUso9mHQWgjA6hM3tEVhjdG8wYo7ePFh1XbwEFzEo3XUQqkGoKZ/Td2Q==} - peerDependencies: - '@vitest/runner': ^2.0.0 || ^3.0.2 || ^4.0.0 - vitest: ^2.0.0 || ^3.0.2 || ^4.0.0 - vitest@4.0.17: resolution: {integrity: sha512-FQMeF0DJdWY0iOnbv466n/0BudNdKj1l5jYgl5JVTwjSsZSlqyXFt/9+1sEyhR6CLowbZpV7O1sCHrzBhucKKg==} engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} @@ -15351,12 +15342,6 @@ snapshots: moo-color: 1.0.3 vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) - vitest-tiny-reporter@1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17): - dependencies: - '@vitest/runner': 4.0.17 - tinyrainbow: 3.0.3 - vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) - vitest@4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2): dependencies: '@vitest/expect': 4.0.17 diff --git a/web/vitest.config.ts b/web/vitest.config.ts index 370bc74904..79486b6b4b 100644 --- a/web/vitest.config.ts +++ b/web/vitest.config.ts @@ -1,6 +1,8 @@ import { defineConfig, mergeConfig } from 'vitest/config' import viteConfig from './vite.config' +const isCI = !!process.env.CI + export default mergeConfig(viteConfig, defineConfig({ test: { environment: 'jsdom', @@ -8,7 +10,7 @@ export default mergeConfig(viteConfig, defineConfig({ setupFiles: ['./vitest.setup.ts'], coverage: { provider: 'v8', - reporter: ['json', 'json-summary'], + reporter: isCI ? 
['json', 'json-summary'] : ['text', 'json', 'json-summary'], }, }, })) From 552ee369b2be7bf2084233fad4e3bc47137d67b8 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Fri, 6 Feb 2026 16:14:05 +0800 Subject: [PATCH 5/8] chore: update deploy branches for deploy-hitl.yaml (#32051) --- .github/workflows/deploy-hitl.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml index 6a4702d2da..a3fd52afc6 100644 --- a/.github/workflows/deploy-hitl.yml +++ b/.github/workflows/deploy-hitl.yml @@ -4,7 +4,7 @@ on: workflow_run: workflows: ["Build and Push API & Web"] branches: - - "feat/hitl" + - "build/feat/hitl" types: - completed @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest if: | github.event.workflow_run.conclusion == 'success' && - github.event.workflow_run.head_branch == 'feat/hitl' + github.event.workflow_run.head_branch == 'build/feat/hitl' steps: - name: Deploy to server uses: appleboy/ssh-action@v1 From 2c9430313dba35ce948cf342a28f1bed98b93b1e Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Fri, 6 Feb 2026 16:25:27 +0800 Subject: [PATCH 6/8] fix: redis for api token (#31861) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: hj24 --- api/README.md | 2 +- api/configs/feature/__init__.py | 10 + api/controllers/console/apikey.py | 6 + api/controllers/console/datasets/datasets.py | 6 + api/controllers/service_api/wraps.py | 49 +-- api/docker/entrypoint.sh | 4 +- api/extensions/ext_celery.py | 8 + .../update_api_token_last_used_task.py | 114 ++++++ api/services/api_token_service.py | 330 +++++++++++++++ api/tasks/remove_app_and_related_data_task.py | 7 + .../libs/test_api_token_cache_integration.py | 375 ++++++++++++++++++ .../unit_tests/extensions/test_celery_ssl.py | 2 + .../unit_tests/libs/test_api_token_cache.py | 250 ++++++++++++ 13 files changed, 1132 insertions(+), 31 deletions(-) create mode 100644 api/schedule/update_api_token_last_used_task.py create mode 100644 api/services/api_token_service.py create mode 100644 api/tests/integration_tests/libs/test_api_token_cache_integration.py create mode 100644 api/tests/unit_tests/libs/test_api_token_cache.py diff --git a/api/README.md b/api/README.md index 9d89b490b0..b23edeab72 100644 --- a/api/README.md +++ b/api/README.md @@ -122,7 +122,7 @@ These commands assume you start from the repository root. ```bash cd api - uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention + uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention ``` 1. Optional: start Celery Beat (scheduled tasks, in a new terminal). 
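Note on running the new queue locally: the batched `last_used_at` update added in this patch is dispatched on the new `api_token` queue and fired by Celery Beat, so a worker must consume that queue. A minimal sketch, assuming the same `uv run celery -A app.celery` entry point as the worker command shown in the README hunk above (the dedicated single-queue worker is optional; the default command already lists `api_token`):

```bash
# Optional: a dedicated worker for just the api_token queue; the default
# worker command above already consumes it alongside the other queues.
uv run celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q api_token

# Celery Beat triggers batch_update_api_token_last_used every
# API_TOKEN_LAST_USED_UPDATE_INTERVAL minutes (default 30, see config below).
uv run celery -A app.celery beat --loglevel INFO
```
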
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d97e9a0440..e8c1b522de 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1155,6 +1155,16 @@ class CeleryScheduleTasksConfig(BaseSettings): default=0, ) + # API token last_used_at batch update + ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field( + description="Enable periodic batch update of API token last_used_at timestamps", + default=True, + ) + API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field( + description="Interval in minutes for batch updating API token last_used_at (default 30)", + default=30, + ) + # Trigger provider refresh (simple version) ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field( description="Enable trigger provider refresh poller", diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index c81709e985..b6d1df319e 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -10,6 +10,7 @@ from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.model import ApiToken, App +from services.api_token_service import ApiTokenCache from . import console_ns from .wraps import account_initialization_required, edit_permission_required, setup_required @@ -131,6 +132,11 @@ class BaseApiKeyResource(Resource): if key is None: flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") + # Invalidate cache before deleting from database + # Type assertion: key is guaranteed to be non-None here because abort() raises + assert key is not None # nosec - for type checker only + ApiTokenCache.delete(key.token, key.type) + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 30e4ed1119..a06b872846 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -55,6 +55,7 @@ from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID +from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService # Register models for flask_restx to avoid dict type issues in Swagger @@ -820,6 +821,11 @@ class DatasetApiDeleteApi(Resource): if key is None: console_ns.abort(404, message="API key not found") + # Invalidate cache before deleting from database + # Type assertion: key is guaranteed to be non-None here because abort() raises + assert key is not None # nosec - for type checker only + ApiTokenCache.delete(key.token, key.type) + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index e597a72fc0..b80735914d 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,27 +1,24 @@ import logging import time from collections.abc import Callable -from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Concatenate, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import 
user_logged_in from flask_restx import Resource from pydantic import BaseModel -from sqlalchemy import select, update -from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog from models.model import ApiToken, App +from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage from services.end_user_service import EndUserService from services.feature_service import FeatureService @@ -296,7 +293,14 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): def validate_and_get_api_token(scope: str | None = None): """ - Validate and get API token. + Validate and get API token with Redis caching. + + This function uses a two-tier approach: + 1. First checks Redis cache for the token + 2. If not cached, queries database and caches the result + + The last_used_at field is updated asynchronously via Celery task + to avoid blocking the request. """ auth_header = request.headers.get("Authorization") if auth_header is None or " " not in auth_header: @@ -308,29 +312,18 @@ def validate_and_get_api_token(scope: str | None = None): if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - current_time = naive_utc_now() - cutoff_time = current_time - timedelta(minutes=1) - with Session(db.engine, expire_on_commit=False) as session: - update_stmt = ( - update(ApiToken) - .where( - ApiToken.token == auth_token, - (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)), - ApiToken.type == scope, - ) - .values(last_used_at=current_time) - ) - stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) - result = session.execute(update_stmt) - api_token = session.scalar(stmt) + # Try to get token from cache first + # Returns a CachedApiToken (plain Python object), not a SQLAlchemy model + cached_token = ApiTokenCache.get(auth_token, scope) + if cached_token is not None: + logger.debug("Token validation served from cache for scope: %s", scope) + # Record usage in Redis for later batch update (no Celery task per request) + record_token_usage(auth_token, scope) + return cast(ApiToken, cached_token) - if hasattr(result, "rowcount") and result.rowcount > 0: - session.commit() - - if not api_token: - raise Unauthorized("Access token is invalid") - - return api_token + # Cache miss - use Redis lock for single-flight mode + # This ensures only one request queries DB for the same token concurrently + return fetch_token_with_single_flight(auth_token, scope) class DatasetApiResource(Resource): diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index c0279f893b..b0863f0a2c 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - 
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" + DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" + DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index af983f6d87..dea214163c 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -184,6 +184,14 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh", "schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL), } + + if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: + imports.append("schedule.update_api_token_last_used_task") + beat_schedule["batch_update_api_token_last_used"] = { + "task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used", + "schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL), + } + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/schedule/update_api_token_last_used_task.py b/api/schedule/update_api_token_last_used_task.py new file mode 100644 index 0000000000..f0f304a671 --- /dev/null +++ b/api/schedule/update_api_token_last_used_task.py @@ -0,0 +1,114 @@ +""" +Scheduled task to batch-update API token last_used_at timestamps. + +Instead of updating the database on every request, token usage is recorded +in Redis as lightweight SET keys (api_token_active:{scope}:{token}). +This task runs periodically (default every 30 minutes) to flush those +records into the database in a single batch operation. +""" + +import logging +import time +from datetime import datetime + +import click +from sqlalchemy import update +from sqlalchemy.orm import Session + +import app +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.model import ApiToken +from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX + +logger = logging.getLogger(__name__) + + +@app.celery.task(queue="api_token") +def batch_update_api_token_last_used(): + """ + Batch update last_used_at for all recently active API tokens. + + Scans Redis for api_token_active:* keys, parses the token and scope + from each key, and performs a batch database update. 
+ """ + click.echo(click.style("batch_update_api_token_last_used: start.", fg="green")) + start_at = time.perf_counter() + + updated_count = 0 + scanned_count = 0 + + try: + # Collect all active token keys and their values (the actual usage timestamps) + token_entries: list[tuple[str, str | None, datetime]] = [] # (token, scope, usage_time) + keys_to_delete: list[str | bytes] = [] + + for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200): + if isinstance(key, bytes): + key = key.decode("utf-8") + scanned_count += 1 + + # Read the value (ISO timestamp recorded at actual request time) + value = redis_client.get(key) + if not value: + keys_to_delete.append(key) + continue + + if isinstance(value, bytes): + value = value.decode("utf-8") + + try: + usage_time = datetime.fromisoformat(value) + except (ValueError, TypeError): + logger.warning("Invalid timestamp in key %s: %s", key, value) + keys_to_delete.append(key) + continue + + # Parse token info from key: api_token_active:{scope}:{token} + suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX) :] + parts = suffix.split(":", 1) + if len(parts) == 2: + scope_str, token = parts + scope = None if scope_str == "None" else scope_str + token_entries.append((token, scope, usage_time)) + keys_to_delete.append(key) + + if not token_entries: + click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow")) + # Still clean up any invalid keys + if keys_to_delete: + redis_client.delete(*keys_to_delete) + return + + # Update each token in its own short transaction to avoid long transactions + for token, scope, usage_time in token_entries: + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + stmt = ( + update(ApiToken) + .where( + ApiToken.token == token, + ApiToken.type == scope, + (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < usage_time)), + ) + .values(last_used_at=usage_time) + ) + result = session.execute(stmt) + rowcount = getattr(result, "rowcount", 0) + if rowcount > 0: + updated_count += 1 + + # Delete processed keys from Redis + if keys_to_delete: + redis_client.delete(*keys_to_delete) + + except Exception: + logger.exception("batch_update_api_token_last_used failed") + + elapsed = time.perf_counter() - start_at + click.echo( + click.style( + f"batch_update_api_token_last_used: done. " + f"scanned={scanned_count}, updated={updated_count}, elapsed={elapsed:.2f}s", + fg="green", + ) + ) diff --git a/api/services/api_token_service.py b/api/services/api_token_service.py new file mode 100644 index 0000000000..98cb5c0620 --- /dev/null +++ b/api/services/api_token_service.py @@ -0,0 +1,330 @@ +""" +API Token Service + +Handles all API token caching, validation, and usage recording. +Includes Redis cache operations, database queries, and single-flight concurrency control. +""" + +import logging +from datetime import datetime +from typing import Any + +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.orm import Session +from werkzeug.exceptions import Unauthorized + +from extensions.ext_database import db +from extensions.ext_redis import redis_client, redis_fallback +from libs.datetime_utils import naive_utc_now +from models.model import ApiToken + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------- +# Pydantic DTO +# --------------------------------------------------------------------- + + +class CachedApiToken(BaseModel): + """ + Pydantic model for cached API token data. 
+ + This is NOT a SQLAlchemy model instance, but a plain Pydantic model + that mimics the ApiToken model interface for read-only access. + """ + + id: str + app_id: str | None + tenant_id: str | None + type: str + token: str + last_used_at: datetime | None + created_at: datetime | None + + def __repr__(self) -> str: + return f"" + + +# --------------------------------------------------------------------- +# Cache configuration +# --------------------------------------------------------------------- + +CACHE_KEY_PREFIX = "api_token" +CACHE_TTL_SECONDS = 600 # 10 minutes +CACHE_NULL_TTL_SECONDS = 60 # 1 minute for non-existent tokens +ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:" + + +# --------------------------------------------------------------------- +# Cache class +# --------------------------------------------------------------------- + + +class ApiTokenCache: + """ + Redis cache wrapper for API tokens. + Handles serialization, deserialization, and cache invalidation. + """ + + @staticmethod + def make_active_key(token: str, scope: str | None = None) -> str: + """Generate Redis key for recording token usage.""" + return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}" + + @staticmethod + def _make_tenant_index_key(tenant_id: str) -> str: + """Generate Redis key for tenant token index.""" + return f"tenant_tokens:{tenant_id}" + + @staticmethod + def _make_cache_key(token: str, scope: str | None = None) -> str: + """Generate cache key for the given token and scope.""" + scope_str = scope or "any" + return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}" + + @staticmethod + def _serialize_token(api_token: Any) -> bytes: + """Serialize ApiToken object to JSON bytes.""" + if isinstance(api_token, CachedApiToken): + return api_token.model_dump_json().encode("utf-8") + + cached = CachedApiToken( + id=str(api_token.id), + app_id=str(api_token.app_id) if api_token.app_id else None, + tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None, + type=api_token.type, + token=api_token.token, + last_used_at=api_token.last_used_at, + created_at=api_token.created_at, + ) + return cached.model_dump_json().encode("utf-8") + + @staticmethod + def _deserialize_token(cached_data: bytes | str) -> Any: + """Deserialize JSON bytes/string back to a CachedApiToken Pydantic model.""" + if cached_data in {b"null", "null"}: + return None + + try: + if isinstance(cached_data, bytes): + cached_data = cached_data.decode("utf-8") + return CachedApiToken.model_validate_json(cached_data) + except (ValueError, Exception) as e: + logger.warning("Failed to deserialize token from cache: %s", e) + return None + + @staticmethod + @redis_fallback(default_return=None) + def get(token: str, scope: str | None) -> Any | None: + """Get API token from cache.""" + cache_key = ApiTokenCache._make_cache_key(token, scope) + cached_data = redis_client.get(cache_key) + + if cached_data is None: + logger.debug("Cache miss for token key: %s", cache_key) + return None + + logger.debug("Cache hit for token key: %s", cache_key) + return ApiTokenCache._deserialize_token(cached_data) + + @staticmethod + def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None: + """Add cache key to tenant index for efficient invalidation.""" + if not tenant_id: + return + + try: + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + redis_client.sadd(index_key, cache_key) + redis_client.expire(index_key, CACHE_TTL_SECONDS + 60) + except Exception as e: + logger.warning("Failed to update tenant index: %s", e) + + @staticmethod + def 
_remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None: + """Remove cache key from tenant index.""" + if not tenant_id: + return + + try: + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + redis_client.srem(index_key, cache_key) + except Exception as e: + logger.warning("Failed to remove from tenant index: %s", e) + + @staticmethod + @redis_fallback(default_return=False) + def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool: + """Set API token in cache.""" + cache_key = ApiTokenCache._make_cache_key(token, scope) + + if api_token is None: + cached_value = b"null" + ttl = CACHE_NULL_TTL_SECONDS + else: + cached_value = ApiTokenCache._serialize_token(api_token) + + try: + redis_client.setex(cache_key, ttl, cached_value) + + if api_token is not None and hasattr(api_token, "tenant_id"): + ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key) + + logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl) + return True + except Exception as e: + logger.warning("Failed to cache token: %s", e) + return False + + @staticmethod + @redis_fallback(default_return=False) + def delete(token: str, scope: str | None = None) -> bool: + """Delete API token from cache.""" + if scope is None: + pattern = f"{CACHE_KEY_PREFIX}:*:{token}" + try: + keys_to_delete = list(redis_client.scan_iter(match=pattern)) + if keys_to_delete: + redis_client.delete(*keys_to_delete) + logger.info("Deleted %d cache entries for token", len(keys_to_delete)) + return True + except Exception as e: + logger.warning("Failed to delete token cache with pattern: %s", e) + return False + else: + cache_key = ApiTokenCache._make_cache_key(token, scope) + try: + tenant_id = None + try: + cached_data = redis_client.get(cache_key) + if cached_data and cached_data != b"null": + cached_token = ApiTokenCache._deserialize_token(cached_data) + if cached_token: + tenant_id = cached_token.tenant_id + except Exception as e: + logger.debug("Failed to get tenant_id for cache cleanup: %s", e) + + redis_client.delete(cache_key) + + if tenant_id: + ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key) + + logger.info("Deleted cache for key: %s", cache_key) + return True + except Exception as e: + logger.warning("Failed to delete token cache: %s", e) + return False + + @staticmethod + @redis_fallback(default_return=False) + def invalidate_by_tenant(tenant_id: str) -> bool: + """Invalidate all API token caches for a specific tenant via tenant index.""" + try: + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + cache_keys = redis_client.smembers(index_key) + + if cache_keys: + deleted_count = 0 + for cache_key in cache_keys: + if isinstance(cache_key, bytes): + cache_key = cache_key.decode("utf-8") + redis_client.delete(cache_key) + deleted_count += 1 + + redis_client.delete(index_key) + + logger.info( + "Invalidated %d token cache entries for tenant: %s", + deleted_count, + tenant_id, + ) + else: + logger.info( + "No tenant index found for %s, relying on TTL expiration", + tenant_id, + ) + + return True + + except Exception as e: + logger.warning("Failed to invalidate tenant token cache: %s", e) + return False + + +# --------------------------------------------------------------------- +# Token usage recording (for batch update) +# --------------------------------------------------------------------- + + +def record_token_usage(auth_token: str, scope: str | None) -> None: + """ + Record token usage in Redis for later batch update by a 
+
+
+# ---------------------------------------------------------------------
+# Token usage recording (for batch update)
+# ---------------------------------------------------------------------
+
+
+def record_token_usage(auth_token: str, scope: str | None) -> None:
+    """
+    Record token usage in Redis for later batch update by a scheduled job.
+
+    Instead of dispatching a Celery task per request, we simply SET a key in Redis.
+    A Celery Beat scheduled task will periodically scan these keys and batch-update
+    last_used_at in the database.
+    """
+    try:
+        key = ApiTokenCache.make_active_key(auth_token, scope)
+        redis_client.set(key, naive_utc_now().isoformat(), ex=3600)
+    except Exception as e:
+        logger.warning("Failed to record token usage: %s", e)
+
+
+# ---------------------------------------------------------------------
+# Database query + single-flight
+# ---------------------------------------------------------------------
+
+
+def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
+    """
+    Query API token from database and cache the result.
+
+    Raises Unauthorized if token is invalid.
+    """
+    with Session(db.engine, expire_on_commit=False) as session:
+        stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
+        api_token = session.scalar(stmt)
+
+        if not api_token:
+            ApiTokenCache.set(auth_token, scope, None)
+            raise Unauthorized("Access token is invalid")
+
+        ApiTokenCache.set(auth_token, scope, api_token)
+        record_token_usage(auth_token, scope)
+        return api_token
+
+
+def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any:
+    """
+    Fetch token from DB with single-flight pattern using Redis lock.
+
+    Ensures only one concurrent request queries the database for the same token.
+    Falls back to direct query if lock acquisition fails.
+    """
+    logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
+
+    lock_key = f"api_token_query_lock:{scope}:{auth_token}"
+    lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)
+
+    try:
+        if lock.acquire(blocking=True):
+            try:
+                cached_token = ApiTokenCache.get(auth_token, scope)
+                if cached_token is not None:
+                    logger.debug("Token cached by concurrent request, using cached version")
+                    return cached_token
+
+                return query_token_from_db(auth_token, scope)
+            finally:
+                lock.release()
+        else:
+            logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
+            return query_token_from_db(auth_token, scope)
+    except Unauthorized:
+        raise
+    except Exception as e:
+        logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
+        return query_token_from_db(auth_token, scope)
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 6240f2200f..b1840662ff 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -48,6 +48,7 @@ from models.workflow import (
     WorkflowArchiveLog,
 )
 from repositories.factory import DifyAPIRepositoryFactory
+from services.api_token_service import ApiTokenCache
 
 logger = logging.getLogger(__name__)
 
@@ -134,6 +135,12 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
 
 def _delete_app_api_tokens(tenant_id: str, app_id: str):
     def del_api_token(session, api_token_id: str):
+        # Fetch token details for cache invalidation
+        token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
+        if token_obj:
+            # Invalidate cache before deletion
+            ApiTokenCache.delete(token_obj.token, token_obj.type)
+
         session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
 
     _delete_records(
diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py
new file mode 100644
index 0000000000..166fcb515f --- /dev/null +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -0,0 +1,375 @@ +""" +Integration tests for API Token Cache with Redis. + +These tests require: +- Redis server running +- Test database configured +""" + +import time +from datetime import datetime, timedelta +from unittest.mock import patch + +import pytest + +from extensions.ext_redis import redis_client +from models.model import ApiToken +from services.api_token_service import ApiTokenCache, CachedApiToken + + +class TestApiTokenCacheRedisIntegration: + """Integration tests with real Redis.""" + + def setup_method(self): + """Setup test fixtures and clean Redis.""" + self.test_token = "test-integration-token-123" + self.test_scope = "app" + self.cache_key = f"api_token:{self.test_scope}:{self.test_token}" + + # Clean up any existing test data + self._cleanup() + + def teardown_method(self): + """Cleanup test data from Redis.""" + self._cleanup() + + def _cleanup(self): + """Remove test data from Redis.""" + try: + redis_client.delete(self.cache_key) + redis_client.delete(ApiTokenCache._make_tenant_index_key("test-tenant-id")) + redis_client.delete(ApiTokenCache.make_active_key(self.test_token, self.test_scope)) + except Exception: + pass # Ignore cleanup errors + + def test_cache_set_and_get_with_real_redis(self): + """Test cache set and get operations with real Redis.""" + from unittest.mock import MagicMock + + mock_token = MagicMock() + mock_token.id = "test-id-123" + mock_token.app_id = "test-app-456" + mock_token.tenant_id = "test-tenant-789" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = datetime.now() + mock_token.created_at = datetime.now() - timedelta(days=30) + + # Set in cache + result = ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + assert result is True + + # Verify in Redis + cached_data = redis_client.get(self.cache_key) + assert cached_data is not None + + # Get from cache + cached_token = ApiTokenCache.get(self.test_token, self.test_scope) + assert cached_token is not None + assert isinstance(cached_token, CachedApiToken) + assert cached_token.id == "test-id-123" + assert cached_token.app_id == "test-app-456" + assert cached_token.tenant_id == "test-tenant-789" + assert cached_token.type == "app" + assert cached_token.token == self.test_token + + def test_cache_ttl_with_real_redis(self): + """Test cache TTL is set correctly.""" + from unittest.mock import MagicMock + + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + + ttl = redis_client.ttl(self.cache_key) + assert 595 <= ttl <= 600 # Should be around 600 seconds (10 minutes) + + def test_cache_null_value_for_invalid_token(self): + """Test caching null value for invalid tokens.""" + result = ApiTokenCache.set(self.test_token, self.test_scope, None) + assert result is True + + cached_data = redis_client.get(self.cache_key) + assert cached_data == b"null" + + cached_token = ApiTokenCache.get(self.test_token, self.test_scope) + assert cached_token is None + + ttl = redis_client.ttl(self.cache_key) + assert 55 <= ttl <= 60 + + def test_cache_delete_with_real_redis(self): + """Test cache deletion with real Redis.""" + from unittest.mock import MagicMock + + 
mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + assert redis_client.exists(self.cache_key) == 1 + + result = ApiTokenCache.delete(self.test_token, self.test_scope) + assert result is True + assert redis_client.exists(self.cache_key) == 0 + + def test_tenant_index_creation(self): + """Test tenant index is created when caching token.""" + from unittest.mock import MagicMock + + tenant_id = "test-tenant-id" + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = tenant_id + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + assert redis_client.exists(index_key) == 1 + + members = redis_client.smembers(index_key) + cache_keys = [m.decode("utf-8") if isinstance(m, bytes) else m for m in members] + assert self.cache_key in cache_keys + + def test_invalidate_by_tenant_via_index(self): + """Test tenant-wide cache invalidation using index (fast path).""" + from unittest.mock import MagicMock + + tenant_id = "test-tenant-id" + + for i in range(3): + token_value = f"test-token-{i}" + mock_token = MagicMock() + mock_token.id = f"test-id-{i}" + mock_token.app_id = "test-app" + mock_token.tenant_id = tenant_id + mock_token.type = "app" + mock_token.token = token_value + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(token_value, "app", mock_token) + + for i in range(3): + key = f"api_token:app:test-token-{i}" + assert redis_client.exists(key) == 1 + + result = ApiTokenCache.invalidate_by_tenant(tenant_id) + assert result is True + + for i in range(3): + key = f"api_token:app:test-token-{i}" + assert redis_client.exists(key) == 0 + + assert redis_client.exists(ApiTokenCache._make_tenant_index_key(tenant_id)) == 0 + + def test_concurrent_cache_access(self): + """Test concurrent cache access doesn't cause issues.""" + import concurrent.futures + from unittest.mock import MagicMock + + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + + def get_from_cache(): + return ApiTokenCache.get(self.test_token, self.test_scope) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(get_from_cache) for _ in range(50)] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + assert len(results) == 50 + assert all(r is not None for r in results) + assert all(isinstance(r, CachedApiToken) for r in results) + + +class TestTokenUsageRecording: + """Tests for recording token usage in Redis (batch update approach).""" + + def setup_method(self): + self.test_token = "test-usage-token" + self.test_scope = "app" + self.active_key = ApiTokenCache.make_active_key(self.test_token, self.test_scope) + + def teardown_method(self): + try: + redis_client.delete(self.active_key) + except 
Exception: + pass + + def test_record_token_usage_sets_redis_key(self): + """Test that record_token_usage writes an active key to Redis.""" + from services.api_token_service import record_token_usage + + record_token_usage(self.test_token, self.test_scope) + + # Key should exist + assert redis_client.exists(self.active_key) == 1 + + # Value should be an ISO timestamp + value = redis_client.get(self.active_key) + if isinstance(value, bytes): + value = value.decode("utf-8") + datetime.fromisoformat(value) # Should not raise + + def test_record_token_usage_has_ttl(self): + """Test that active keys have a TTL as safety net.""" + from services.api_token_service import record_token_usage + + record_token_usage(self.test_token, self.test_scope) + + ttl = redis_client.ttl(self.active_key) + assert 3595 <= ttl <= 3600 # ~1 hour + + def test_record_token_usage_overwrites(self): + """Test that repeated calls overwrite the same key (no accumulation).""" + from services.api_token_service import record_token_usage + + record_token_usage(self.test_token, self.test_scope) + first_value = redis_client.get(self.active_key) + + time.sleep(0.01) # Tiny delay so timestamp differs + + record_token_usage(self.test_token, self.test_scope) + second_value = redis_client.get(self.active_key) + + # Key count should still be 1 (overwritten, not accumulated) + assert redis_client.exists(self.active_key) == 1 + + +class TestEndToEndCacheFlow: + """End-to-end integration test for complete cache flow.""" + + @pytest.mark.usefixtures("db_session") + def test_complete_flow_cache_miss_then_hit(self, db_session): + """ + Test complete flow: + 1. First request (cache miss) -> query DB -> cache result + 2. Second request (cache hit) -> return from cache + 3. Verify Redis state + """ + test_token_value = "test-e2e-token" + test_scope = "app" + + test_token = ApiToken() + test_token.id = "test-e2e-id" + test_token.token = test_token_value + test_token.type = test_scope + test_token.app_id = "test-app" + test_token.tenant_id = "test-tenant" + test_token.last_used_at = None + test_token.created_at = datetime.now() + + db_session.add(test_token) + db_session.commit() + + try: + # Step 1: Cache miss - set token in cache + ApiTokenCache.set(test_token_value, test_scope, test_token) + + cache_key = f"api_token:{test_scope}:{test_token_value}" + assert redis_client.exists(cache_key) == 1 + + # Step 2: Cache hit - get from cache + cached_token = ApiTokenCache.get(test_token_value, test_scope) + assert cached_token is not None + assert cached_token.id == test_token.id + assert cached_token.token == test_token_value + + # Step 3: Verify tenant index + index_key = ApiTokenCache._make_tenant_index_key(test_token.tenant_id) + assert redis_client.exists(index_key) == 1 + assert cache_key.encode() in redis_client.smembers(index_key) + + # Step 4: Delete and verify cleanup + ApiTokenCache.delete(test_token_value, test_scope) + assert redis_client.exists(cache_key) == 0 + assert cache_key.encode() not in redis_client.smembers(index_key) + + finally: + db_session.delete(test_token) + db_session.commit() + redis_client.delete(f"api_token:{test_scope}:{test_token_value}") + redis_client.delete(ApiTokenCache._make_tenant_index_key(test_token.tenant_id)) + + def test_high_concurrency_simulation(self): + """Simulate high concurrency access to cache.""" + import concurrent.futures + from unittest.mock import MagicMock + + test_token_value = "test-concurrent-token" + test_scope = "app" + + mock_token = MagicMock() + mock_token.id = "concurrent-id" + 
mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = test_scope + mock_token.token = test_token_value + mock_token.last_used_at = datetime.now() + mock_token.created_at = datetime.now() + + ApiTokenCache.set(test_token_value, test_scope, mock_token) + + try: + + def read_cache(): + return ApiTokenCache.get(test_token_value, test_scope) + + start_time = time.time() + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(read_cache) for _ in range(100)] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + elapsed = time.time() - start_time + + assert len(results) == 100 + assert all(r is not None for r in results) + + assert elapsed < 1.0, f"Too slow: {elapsed}s for 100 cache reads" + + finally: + ApiTokenCache.delete(test_token_value, test_scope) + redis_client.delete(ApiTokenCache._make_tenant_index_key(mock_token.tenant_id)) + + +class TestRedisFailover: + """Test behavior when Redis is unavailable.""" + + @patch("services.api_token_service.redis_client") + def test_graceful_degradation_when_redis_fails(self, mock_redis): + """Test system degrades gracefully when Redis is unavailable.""" + from redis import RedisError + + mock_redis.get.side_effect = RedisError("Connection failed") + mock_redis.setex.side_effect = RedisError("Connection failed") + + result_get = ApiTokenCache.get("test-token", "app") + assert result_get is None + + result_set = ApiTokenCache.set("test-token", "app", None) + assert result_set is False diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index d3a4d69f07..34d48fa94e 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -132,6 +132,8 @@ class TestCelerySSLConfiguration: mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0 mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15 + mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False + mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30 with patch("extensions.ext_celery.dify_config", mock_config): from dify_app import DifyApp diff --git a/api/tests/unit_tests/libs/test_api_token_cache.py b/api/tests/unit_tests/libs/test_api_token_cache.py new file mode 100644 index 0000000000..fa4c5e77a7 --- /dev/null +++ b/api/tests/unit_tests/libs/test_api_token_cache.py @@ -0,0 +1,250 @@ +""" +Unit tests for API Token Cache module. 
+""" + +import json +from datetime import datetime +from unittest.mock import MagicMock, patch + +from services.api_token_service import ( + CACHE_KEY_PREFIX, + CACHE_NULL_TTL_SECONDS, + CACHE_TTL_SECONDS, + ApiTokenCache, + CachedApiToken, +) + + +class TestApiTokenCache: + """Test cases for ApiTokenCache class.""" + + def setup_method(self): + """Setup test fixtures.""" + self.mock_token = MagicMock() + self.mock_token.id = "test-token-id-123" + self.mock_token.app_id = "test-app-id-456" + self.mock_token.tenant_id = "test-tenant-id-789" + self.mock_token.type = "app" + self.mock_token.token = "test-token-value-abc" + self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0) + self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0) + + def test_make_cache_key(self): + """Test cache key generation.""" + # Test with scope + key = ApiTokenCache._make_cache_key("my-token", "app") + assert key == f"{CACHE_KEY_PREFIX}:app:my-token" + + # Test without scope + key = ApiTokenCache._make_cache_key("my-token", None) + assert key == f"{CACHE_KEY_PREFIX}:any:my-token" + + def test_serialize_token(self): + """Test token serialization.""" + serialized = ApiTokenCache._serialize_token(self.mock_token) + data = json.loads(serialized) + + assert data["id"] == "test-token-id-123" + assert data["app_id"] == "test-app-id-456" + assert data["tenant_id"] == "test-tenant-id-789" + assert data["type"] == "app" + assert data["token"] == "test-token-value-abc" + assert data["last_used_at"] == "2026-02-03T10:00:00" + assert data["created_at"] == "2026-01-01T00:00:00" + + def test_serialize_token_with_nulls(self): + """Test token serialization with None values.""" + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = None + mock_token.tenant_id = None + mock_token.type = "dataset" + mock_token.token = "test-token" + mock_token.last_used_at = None + mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0) + + serialized = ApiTokenCache._serialize_token(mock_token) + data = json.loads(serialized) + + assert data["app_id"] is None + assert data["tenant_id"] is None + assert data["last_used_at"] is None + + def test_deserialize_token(self): + """Test token deserialization.""" + cached_data = json.dumps( + { + "id": "test-id", + "app_id": "test-app", + "tenant_id": "test-tenant", + "type": "app", + "token": "test-token", + "last_used_at": "2026-02-03T10:00:00", + "created_at": "2026-01-01T00:00:00", + } + ) + + result = ApiTokenCache._deserialize_token(cached_data) + + assert isinstance(result, CachedApiToken) + assert result.id == "test-id" + assert result.app_id == "test-app" + assert result.tenant_id == "test-tenant" + assert result.type == "app" + assert result.token == "test-token" + assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0) + assert result.created_at == datetime(2026, 1, 1, 0, 0, 0) + + def test_deserialize_null_token(self): + """Test deserialization of null token (cached miss).""" + result = ApiTokenCache._deserialize_token("null") + assert result is None + + def test_deserialize_invalid_json(self): + """Test deserialization with invalid JSON.""" + result = ApiTokenCache._deserialize_token("invalid-json{") + assert result is None + + @patch("services.api_token_service.redis_client") + def test_get_cache_hit(self, mock_redis): + """Test cache hit scenario.""" + cached_data = json.dumps( + { + "id": "test-id", + "app_id": "test-app", + "tenant_id": "test-tenant", + "type": "app", + "token": "test-token", + "last_used_at": "2026-02-03T10:00:00", + "created_at": 
"2026-01-01T00:00:00", + } + ).encode("utf-8") + mock_redis.get.return_value = cached_data + + result = ApiTokenCache.get("test-token", "app") + + assert result is not None + assert isinstance(result, CachedApiToken) + assert result.app_id == "test-app" + mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token") + + @patch("services.api_token_service.redis_client") + def test_get_cache_miss(self, mock_redis): + """Test cache miss scenario.""" + mock_redis.get.return_value = None + + result = ApiTokenCache.get("test-token", "app") + + assert result is None + mock_redis.get.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_set_valid_token(self, mock_redis): + """Test setting a valid token in cache.""" + result = ApiTokenCache.set("test-token", "app", self.mock_token) + + assert result is True + mock_redis.setex.assert_called_once() + args = mock_redis.setex.call_args[0] + assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token" + assert args[1] == CACHE_TTL_SECONDS + + @patch("services.api_token_service.redis_client") + def test_set_null_token(self, mock_redis): + """Test setting a null token (cache penetration prevention).""" + result = ApiTokenCache.set("invalid-token", "app", None) + + assert result is True + mock_redis.setex.assert_called_once() + args = mock_redis.setex.call_args[0] + assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token" + assert args[1] == CACHE_NULL_TTL_SECONDS + assert args[2] == b"null" + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope(self, mock_redis): + """Test deleting token cache with specific scope.""" + result = ApiTokenCache.delete("test-token", "app") + + assert result is True + mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token") + + @patch("services.api_token_service.redis_client") + def test_delete_without_scope(self, mock_redis): + """Test deleting token cache without scope (delete all).""" + # Mock scan_iter to return an iterator of keys + mock_redis.scan_iter.return_value = iter( + [ + b"api_token:app:test-token", + b"api_token:dataset:test-token", + ] + ) + + result = ApiTokenCache.delete("test-token", None) + + assert result is True + # Verify scan_iter was called with the correct pattern + mock_redis.scan_iter.assert_called_once() + call_args = mock_redis.scan_iter.call_args + assert call_args[1]["match"] == f"{CACHE_KEY_PREFIX}:*:test-token" + + # Verify delete was called with all matched keys + mock_redis.delete.assert_called_once_with( + b"api_token:app:test-token", + b"api_token:dataset:test-token", + ) + + @patch("services.api_token_service.redis_client") + def test_redis_fallback_on_exception(self, mock_redis): + """Test Redis fallback when Redis is unavailable.""" + from redis import RedisError + + mock_redis.get.side_effect = RedisError("Connection failed") + + result = ApiTokenCache.get("test-token", "app") + + # Should return None (fallback) instead of raising exception + assert result is None + + +class TestApiTokenCacheIntegration: + """Integration test scenarios.""" + + @patch("services.api_token_service.redis_client") + def test_full_cache_lifecycle(self, mock_redis): + """Test complete cache lifecycle: set -> get -> delete.""" + # Setup mock token + mock_token = MagicMock() + mock_token.id = "id-123" + mock_token.app_id = "app-456" + mock_token.tenant_id = "tenant-789" + mock_token.type = "app" + mock_token.token = "token-abc" + mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0) + mock_token.created_at = 
datetime(2026, 1, 1, 0, 0, 0) + + # 1. Set token in cache + ApiTokenCache.set("token-abc", "app", mock_token) + assert mock_redis.setex.called + + # 2. Simulate cache hit + cached_data = ApiTokenCache._serialize_token(mock_token) + mock_redis.get.return_value = cached_data # bytes from model_dump_json().encode() + + retrieved = ApiTokenCache.get("token-abc", "app") + assert retrieved is not None + assert isinstance(retrieved, CachedApiToken) + + # 3. Delete from cache + ApiTokenCache.delete("token-abc", "app") + assert mock_redis.delete.called + + @patch("services.api_token_service.redis_client") + def test_cache_penetration_prevention(self, mock_redis): + """Test that non-existent tokens are cached as null.""" + # Set null token (cache miss) + ApiTokenCache.set("non-existent-token", "app", None) + + args = mock_redis.setex.call_args[0] + assert args[2] == b"null" + assert args[1] == CACHE_NULL_TTL_SECONDS # Shorter TTL for null values From 4430a1b3da4924e42974708e2b30571e5d473af6 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:02:14 +0800 Subject: [PATCH 7/8] fix: batch delete document db session block (#32062) --- api/services/dataset_service.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 1ea6c4e1c3..b208e394b0 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1696,13 +1696,18 @@ class DocumentService: for document in documents if document.data_source_type == "upload_file" and document.data_source_info_dict ] - if dataset.doc_form is not None: - batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + # Delete documents first, then dispatch cleanup task after commit + # to avoid deadlock between main transaction and async task for document in documents: db.session.delete(document) db.session.commit() + # Dispatch cleanup task after commit to avoid lock contention + # Task cleans up segments, files, and vector indexes + if dataset.doc_form is not None: + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + @staticmethod def rename_document(dataset_id: str, document_id: str, name: str) -> Document: assert isinstance(current_user, Account) From c185a51bad8ca441eac68b60c18e39706b297264 Mon Sep 17 00:00:00 2001 From: Crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Sat, 7 Feb 2026 13:23:01 +0800 Subject: [PATCH 8/8] fix: remove unexpected scrollbar in KB Retrieval settings (#32082) --- .../dataset-config/params-config/config-content.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 7fbb1bc7c4..69032b4743 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -194,11 +194,11 @@ const ConfigContent: FC = ({ {type === RETRIEVE_TYPE.multiWay && ( <> -
-
+
+
{t('rerankSettings', { ns: 'dataset' })}
-
+
            {
              selectedDatasetsMode.inconsistentEmbeddingModel