diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d34b4124ae..2c094aa3e6 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,7 +6,7 @@ from uuid import UUID from flask import request, send_file from flask_restx import marshal from pydantic import BaseModel, Field, field_validator, model_validator -from sqlalchemy import desc, select +from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") @@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -425,7 +431,9 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) query_params = DocumentListQuery.model_validate(request.args.to_dict()) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # get documents @@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index f8c6b251eb..28fa915117 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -3,6 +3,7 @@ from typing import Any from flask import request from flask_restx import marshal from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -92,7 +93,9 @@ class SegmentApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create single segment.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -150,7 +153,9 @@ class SegmentApi(DatasetApiResource): # check dataset page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -220,7 +225,9 @@ class DatasetSegmentApi(DatasetApiResource): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -254,7 +261,9 @@ class DatasetSegmentApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -301,7 +310,9 @@ class DatasetSegmentApi(DatasetApiResource): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -344,7 +355,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -402,7 +415,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -468,7 +483,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -527,7 +544,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 73a87761d5..7f5d6b0839 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -788,7 +788,7 @@ class TestSegmentApiGet: """Test successful segment list retrieval.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -813,7 +813,7 @@ class TestSegmentApiGet: """Test 404 when dataset not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -833,7 +833,7 @@ class TestSegmentApiGet: """Test 404 when document not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None # Act & Assert @@ -899,7 +899,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -950,7 +950,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -992,7 +992,7 @@ class TestSegmentApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "indexing" # Not completed @@ -1043,7 +1043,7 @@ class TestDatasetSegmentApiDelete: """Test successful segment deletion.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock() @@ -1087,7 +1087,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when segment not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -1129,7 +1129,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when dataset not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1163,7 +1163,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when document not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1233,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = mock_segment @@ -1280,7 +1280,7 @@ class TestDatasetSegmentApiUpdate: """Test 404 when dataset not found for update.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1321,7 +1321,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1370,7 +1370,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test successful single segment retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc @@ -1405,7 +1405,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1436,7 +1436,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1471,7 +1471,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1515,7 +1515,7 @@ class TestChildChunkApiGet: ): """Test successful child chunk list retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() @@ -1554,7 +1554,7 @@ class TestChildChunkApiGet: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1583,7 +1583,7 @@ class TestChildChunkApiGet: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None with app.test_request_context( @@ -1615,7 +1615,7 @@ class TestChildChunkApiGet: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1676,7 +1676,7 @@ class TestChildChunkApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() mock_child = Mock() @@ -1717,7 +1717,7 @@ class TestChildChunkApiPost: """Test 404 when dataset not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1755,7 +1755,7 @@ class TestChildChunkApiPost: """Test 404 when segment not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1808,7 +1808,7 @@ class TestDatasetChildChunkApiDelete: ): """Test successful child chunk deletion.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc_svc.get_document.return_value = mock_doc @@ -1858,7 +1858,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1899,7 +1899,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when segment does not belong to the document.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1939,7 +1939,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk does not belong to the segment.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 7f77e61ee4..12d5e7345d 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -717,7 +717,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = False @@ -746,7 +746,7 @@ class TestDocumentApiDelete: document_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None @@ -767,7 +767,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = True @@ -788,7 +788,7 @@ class TestDocumentApiDelete: # Arrange dataset_id = str(uuid.uuid4()) document_id = str(uuid.uuid4()) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -809,7 +809,7 @@ class TestDocumentListApi: def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset): """Test successful document list retrieval.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_pagination = Mock() mock_pagination.items = [Mock(), Mock()] @@ -838,7 +838,7 @@ class TestDocumentListApi: def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset): """Test 404 when dataset not found.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -860,8 +860,6 @@ class TestDocumentIndexingStatusApi: """Test successful indexing status retrieval.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc = Mock() mock_doc.id = str(uuid.uuid4()) mock_doc.is_paused = False @@ -877,8 +875,8 @@ class TestDocumentIndexingStatusApi: mock_doc_svc.get_batch_documents.return_value = [mock_doc] - # Mock segment count queries - mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5 + # scalar() called 3 times: dataset lookup, completed_segments count, total_segments count + mock_db.session.scalar.side_effect = [mock_dataset, 5, 5] mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"} # Act @@ -898,7 +896,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when dataset not found.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -915,7 +913,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when no documents found for batch.""" # Arrange batch_id = "batch_empty" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_batch_documents.return_value = [] # Act & Assert @@ -986,7 +984,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset.indexing_technique = "economy" mock_current_user.id = str(uuid.uuid4()) @@ -1035,7 +1033,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1064,7 +1062,7 @@ class TestDocumentAddByTextApi: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset # Act & Assert with app.test_request_context( @@ -1150,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost: _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = "economy" mock_dataset.latest_process_rule = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() @@ -1193,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None doc_id = str(uuid.uuid4()) with app.test_request_context( @@ -1232,7 +1230,7 @@ class TestDocumentAddByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1263,7 +1261,7 @@ class TestDocumentAddByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1298,7 +1296,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = "economy" mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset with app.test_request_context( f"/datasets/{mock_dataset.id}/document/create_by_file", @@ -1328,7 +1326,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = None mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1366,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1402,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1450,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost: mock_dataset.chunk_structure = None mock_dataset.latest_process_rule = Mock() mock_dataset.created_by_account = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py index 602488d6ea..542179e2a3 100644 --- a/api/tests/unit_tests/services/dataset_service_test_helpers.py +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -10,13 +10,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from werkzeug.exceptions import Forbidden, NotFound from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from enums.cloud_plan import CloudPlan from models import Account, TenantAccountRole from models.dataset import ( diff --git a/api/uv.lock b/api/uv.lock index 2faa0d38c1..fb08594fb3 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -5340,11 +5340,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.9.1" +version = "6.9.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" } +sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" }, + { url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" }, ] [[package]] diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx index dd2f74f7e5..a16ae9d823 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx @@ -3,13 +3,10 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' import { BLUR_COMMAND, - COMMAND_PRIORITY_EDITOR, FOCUS_COMMAND, - KEY_ESCAPE_COMMAND, } from 'lexical' import OnBlurBlock from '../on-blur-or-focus-block' import { CaptureEditorPlugin } from '../test-utils' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const renderOnBlurBlock = (props?: { onBlur?: () => void @@ -75,7 +72,7 @@ describe('OnBlurBlock', () => { expect(onFocus).toHaveBeenCalledTimes(1) }) - it('should call onBlur and dispatch escape after delay when blur target is not var-search-input', async () => { + it('should call onBlur when blur target is not var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -85,14 +82,6 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) let handled = false act(() => { @@ -101,18 +90,9 @@ describe('OnBlurBlock', () => { expect(handled).toBe(true) expect(onBlur).toHaveBeenCalledTimes(1) - expect(onEscape).not.toHaveBeenCalled() - - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() }) - it('should dispatch delayed escape when onBlur callback is not provided', async () => { + it('should handle blur when onBlur callback is not provided', async () => { const { getEditor } = renderOnBlurBlock() await waitFor(() => { @@ -121,28 +101,16 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) + let handled = false act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - vi.advanceTimersByTime(200) + handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) }) - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() + expect(handled).toBe(true) }) - it('should skip onBlur and delayed escape when blur target is var-search-input', async () => { + it('should skip onBlur when blur target is var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -152,31 +120,17 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() const target = document.createElement('input') target.classList.add('var-search-input') - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - let handled = false act(() => { handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(target)) }) - act(() => { - vi.advanceTimersByTime(200) - }) expect(handled).toBe(true) expect(onBlur).not.toHaveBeenCalled() - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() }) it('should handle focus command when onFocus callback is not provided', async () => { @@ -198,59 +152,6 @@ describe('OnBlurBlock', () => { }) }) - describe('Clear timeout command', () => { - it('should clear scheduled escape timeout when clear command is dispatched', async () => { - const { getEditor } = renderOnBlurBlock({ onBlur: vi.fn() }) - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - - act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() - }) - - it('should handle clear command when no timeout is scheduled', async () => { - const { getEditor } = renderOnBlurBlock() - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - expect(editor).not.toBeNull() - - let handled = false - act(() => { - handled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - - expect(handled).toBe(true) - }) - }) - describe('Lifecycle cleanup', () => { it('should unregister commands when component unmounts', async () => { const { getEditor, unmount } = renderOnBlurBlock() @@ -266,16 +167,13 @@ describe('OnBlurBlock', () => { let blurHandled = true let focusHandled = true - let clearHandled = true act(() => { blurHandled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) focusHandled = editor!.dispatchCommand(FOCUS_COMMAND, createFocusEvent()) - clearHandled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) expect(blurHandled).toBe(false) expect(focusHandled).toBe(false) - expect(clearHandled).toBe(false) }) }) }) diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx index 8f6a72a7de..4283910c31 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx @@ -1,14 +1,13 @@ import type { LexicalEditor } from 'lexical' import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' -import { $getRoot, COMMAND_PRIORITY_EDITOR } from 'lexical' +import { $getRoot } from 'lexical' import { CustomTextNode } from '../custom-text/node' import { CaptureEditorPlugin } from '../test-utils' import UpdateBlock, { PROMPT_EDITOR_INSERT_QUICKLY, PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, } from '../update-block' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const { mockUseEventEmitterContextContext } = vi.hoisted(() => ({ mockUseEventEmitterContextContext: vi.fn(), @@ -157,7 +156,7 @@ describe('UpdateBlock', () => { }) describe('Quick insert event', () => { - it('should insert slash and dispatch clear command when quick insert event matches instance id', async () => { + it('should insert slash when quick insert event matches instance id', async () => { const { emit, getEditor } = setup({ instanceId: 'instance-1' }) await waitFor(() => { @@ -168,13 +167,6 @@ describe('UpdateBlock', () => { selectRootEnd(editor!) - const clearCommandHandler = vi.fn(() => true) - const unregister = editor!.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - clearCommandHandler, - COMMAND_PRIORITY_EDITOR, - ) - emit({ type: PROMPT_EDITOR_INSERT_QUICKLY, instanceId: 'instance-1', @@ -183,9 +175,6 @@ describe('UpdateBlock', () => { await waitFor(() => { expect(readEditorText(editor!)).toBe('/') }) - expect(clearCommandHandler).toHaveBeenCalledTimes(1) - - unregister() }) it('should ignore quick insert event when instance id does not match', async () => { diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx index 6cc6c3a67f..51b14b76c8 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx @@ -23,6 +23,8 @@ import { $createTextNode, $getRoot, $setSelection, + BLUR_COMMAND, + FOCUS_COMMAND, KEY_ESCAPE_COMMAND, } from 'lexical' import * as React from 'react' @@ -631,4 +633,180 @@ describe('ComponentPicker (component-picker-block/index.tsx)', () => { // With a single option group, the only divider should be the workflow-var/options separator. expect(document.querySelectorAll('.bg-divider-subtle')).toHaveLength(1) }) + + describe('blur/focus menu visibility', () => { + it('hides the menu after a 200ms delay when blur command is dispatched', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + vi.useRealTimers() + }) + + it('restores menu visibility when focus command is dispatched after blur hides it', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + vi.useRealTimers() + + await setEditorText(editor, '{', true) + await waitFor(() => { + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + }) + }) + + it('cancels the blur timer when focus arrives before the 200ms timeout', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('cancels a pending blur timer when a subsequent blur targets var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + const varInput = document.createElement('input') + varInput.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: varInput })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('does not hide the menu when blur target is var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + const target = document.createElement('input') + target.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: target })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + }) }) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx index 8001a2755b..bebc1b59af 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx @@ -21,11 +21,19 @@ import { } from '@floating-ui/react' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin' -import { KEY_ESCAPE_COMMAND } from 'lexical' +import { mergeRegister } from '@lexical/utils' +import { + BLUR_COMMAND, + COMMAND_PRIORITY_EDITOR, + FOCUS_COMMAND, + KEY_ESCAPE_COMMAND, +} from 'lexical' import { Fragment, memo, useCallback, + useEffect, + useRef, useState, } from 'react' import ReactDOM from 'react-dom' @@ -87,6 +95,46 @@ const ComponentPicker = ({ }) const [queryString, setQueryString] = useState(null) + const [blurHidden, setBlurHidden] = useState(false) + const blurTimerRef = useRef | null>(null) + + const clearBlurTimer = useCallback(() => { + if (blurTimerRef.current) { + clearTimeout(blurTimerRef.current) + blurTimerRef.current = null + } + }, []) + + useEffect(() => { + const unregister = mergeRegister( + editor.registerCommand( + BLUR_COMMAND, + (event) => { + clearBlurTimer() + const target = event?.relatedTarget as HTMLElement + if (!target?.classList?.contains('var-search-input')) + blurTimerRef.current = setTimeout(() => setBlurHidden(true), 200) + return false + }, + COMMAND_PRIORITY_EDITOR, + ), + editor.registerCommand( + FOCUS_COMMAND, + () => { + clearBlurTimer() + setBlurHidden(false) + return false + }, + COMMAND_PRIORITY_EDITOR, + ), + ) + + return () => { + if (blurTimerRef.current) + clearTimeout(blurTimerRef.current) + unregister() + } + }, [editor, clearBlurTimer]) eventEmitter?.useSubscription((v: any) => { if (v.type === INSERT_VARIABLE_VALUE_BLOCK_COMMAND) @@ -159,6 +207,8 @@ const ComponentPicker = ({ anchorElementRef, { options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex }, ) => { + if (blurHidden) + return null if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show))) return null @@ -240,7 +290,7 @@ const ComponentPicker = ({ } ) - }, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) + }, [blurHidden, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) return ( void @@ -20,35 +18,13 @@ const OnBlurBlock: FC = ({ }) => { const [editor] = useLexicalComposerContext() - const ref = useRef | null>(null) - useEffect(() => { - const clearHideMenuTimeout = () => { - if (ref.current) { - clearTimeout(ref.current) - ref.current = null - } - } - - const unregister = mergeRegister( - editor.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - () => { - clearHideMenuTimeout() - return true - }, - COMMAND_PRIORITY_EDITOR, - ), + return mergeRegister( editor.registerCommand( BLUR_COMMAND, (event) => { - // Check if the clicked target element is var-search-input const target = event?.relatedTarget as HTMLElement if (!target?.classList?.contains('var-search-input')) { - clearHideMenuTimeout() - ref.current = setTimeout(() => { - editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' })) - }, 200) if (onBlur) onBlur() } @@ -66,11 +42,6 @@ const OnBlurBlock: FC = ({ COMMAND_PRIORITY_EDITOR, ), ) - - return () => { - clearHideMenuTimeout() - unregister() - } }, [editor, onBlur, onFocus]) return null diff --git a/web/app/components/base/prompt-editor/plugins/update-block.tsx b/web/app/components/base/prompt-editor/plugins/update-block.tsx index bf89a259af..2d83573b1f 100644 --- a/web/app/components/base/prompt-editor/plugins/update-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/update-block.tsx @@ -3,7 +3,6 @@ import { $insertNodes } from 'lexical' import { useEventEmitterContextContext } from '@/context/event-emitter' import { textToEditorState } from '../utils' import { CustomTextNode } from './custom-text/node' -import { CLEAR_HIDE_MENU_TIMEOUT } from './workflow-variable-block' export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER' export const PROMPT_EDITOR_INSERT_QUICKLY = 'PROMPT_EDITOR_INSERT_QUICKLY' @@ -30,8 +29,6 @@ const UpdateBlock = ({ editor.update(() => { const textNode = new CustomTextNode('/') $insertNodes([textNode]) - - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) } }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx index ca4973b830..1591dc44f9 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx @@ -9,7 +9,6 @@ import { $insertNodes, COMMAND_PRIORITY_EDITOR } from 'lexical' import { Type } from '@/app/components/workflow/nodes/llm/types' import { BlockEnum } from '@/app/components/workflow/types' import { - CLEAR_HIDE_MENU_TIMEOUT, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND, INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, UPDATE_WORKFLOW_NODES_MAP, @@ -134,7 +133,6 @@ describe('WorkflowVariableBlock', () => { const insertHandler = mockRegisterCommand.mock.calls[0][1] as (variables: string[]) => boolean const result = insertHandler(['node-1', 'answer']) - expect(mockDispatchCommand).toHaveBeenCalledWith(CLEAR_HIDE_MENU_TIMEOUT, undefined) expect($createWorkflowVariableBlockNode).toHaveBeenCalledWith( ['node-1', 'answer'], workflowNodesMap, diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index 76b2795803..c8cac64d19 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -18,7 +18,6 @@ import { export const INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND') export const DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND') -export const CLEAR_HIDE_MENU_TIMEOUT = createCommand('CLEAR_HIDE_MENU_TIMEOUT') export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP') export type WorkflowVariableBlockProps = { @@ -49,7 +48,6 @@ const WorkflowVariableBlock = memo(({ editor.registerCommand( INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) $insertNodes([workflowVariableBlockNode]) diff --git a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx index d8ab1cafdd..38ce5cca4e 100644 --- a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx @@ -142,7 +142,7 @@ const ApiKeyModal = ({ onExtraButtonClick={onRemove} disabled={disabled || isLoading || doingAction} clickOutsideNotClose={true} - wrapperClassName="!z-[101]" + wrapperClassName="!z-[1002]" > {pluginPayload.detail && ( diff --git a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx index 28989da77c..a69cc1538b 100644 --- a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx @@ -157,7 +157,7 @@ const OAuthClientSettings = ({ ) } containerClassName="pt-0" - wrapperClassName="!z-[101]" + wrapperClassName="!z-[1002]" clickOutsideNotClose={true} > {pluginPayload.detail && ( diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index e28d915e66..3d311dbf54 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -2768,7 +2768,7 @@ }, "app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx": { "react-refresh/only-export-components": { - "count": 4 + "count": 3 } }, "app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx": {