diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index d34b4124ae..2c094aa3e6 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -6,7 +6,7 @@ from uuid import UUID
from flask import request, send_file
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator, model_validator
-from sqlalchemy import desc, select
+from sqlalchemy import desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
+ )
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -425,7 +431,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
query_params = DocumentListQuery.model_validate(request.args.to_dict())
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
@@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource):
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
# get documents
@@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource):
documents_status = []
for document in documents:
completed_segments = (
- db.session.query(DocumentSegment)
- .where(
- DocumentSegment.completed_at.isnot(None),
- DocumentSegment.document_id == str(document.id),
- DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+ db.session.scalar(
+ select(func.count(DocumentSegment.id)).where(
+ DocumentSegment.completed_at.isnot(None),
+ DocumentSegment.document_id == str(document.id),
+ DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+ )
)
- .count()
+ or 0
)
total_segments = (
- db.session.query(DocumentSegment)
- .where(
- DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
+ db.session.scalar(
+ select(func.count(DocumentSegment.id)).where(
+ DocumentSegment.document_id == str(document.id),
+ DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+ )
)
- .count()
+ or 0
)
# Create a dictionary with document attributes and additional fields
document_dict = {
@@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource):
tenant_id = str(tenant_id)
# get dataset info
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise ValueError("Dataset does not exist.")
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index f8c6b251eb..28fa915117 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -3,6 +3,7 @@ from typing import Any
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field
+from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
@@ -92,7 +93,9 @@ class SegmentApi(DatasetApiResource):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment."""
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -150,7 +153,9 @@ class SegmentApi(DatasetApiResource):
# check dataset
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -220,7 +225,9 @@ class DatasetSegmentApi(DatasetApiResource):
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -254,7 +261,9 @@ class DatasetSegmentApi(DatasetApiResource):
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -301,7 +310,9 @@ class DatasetSegmentApi(DatasetApiResource):
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -344,7 +355,9 @@ class ChildChunkApi(DatasetApiResource):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk."""
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
@@ -402,7 +415,9 @@ class ChildChunkApi(DatasetApiResource):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
@@ -468,7 +483,9 @@ class DatasetChildChunkApi(DatasetApiResource):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk."""
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
@@ -527,7 +544,9 @@ class DatasetChildChunkApi(DatasetApiResource):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk."""
# check dataset
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.scalar(
+ select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+ )
if not dataset:
raise NotFound("Dataset not found.")
diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py
index 73a87761d5..7f5d6b0839 100644
--- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py
+++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py
@@ -788,7 +788,7 @@ class TestSegmentApiGet:
"""Test successful segment list retrieval."""
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
mock_marshal.return_value = [{"id": mock_segment.id}]
@@ -813,7 +813,7 @@ class TestSegmentApiGet:
"""Test 404 when dataset not found."""
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context(
@@ -833,7 +833,7 @@ class TestSegmentApiGet:
"""Test 404 when document not found."""
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = None
# Act & Assert
@@ -899,7 +899,7 @@ class TestSegmentApiPost:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_dataset.indexing_technique = "economy"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc.indexing_status = "completed"
@@ -950,7 +950,7 @@ class TestSegmentApiPost:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_dataset.indexing_technique = "economy"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc.indexing_status = "completed"
@@ -992,7 +992,7 @@ class TestSegmentApiPost:
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc.indexing_status = "indexing" # Not completed
@@ -1043,7 +1043,7 @@ class TestDatasetSegmentApiDelete:
"""Test successful segment deletion."""
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock()
@@ -1087,7 +1087,7 @@ class TestDatasetSegmentApiDelete:
"""Test 404 when segment not found."""
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc.indexing_status = "completed"
@@ -1129,7 +1129,7 @@ class TestDatasetSegmentApiDelete:
"""Test 404 when dataset not found for delete."""
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context(
@@ -1163,7 +1163,7 @@ class TestDatasetSegmentApiDelete:
"""Test 404 when document not found for delete."""
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = None
@@ -1233,7 +1233,7 @@ class TestDatasetSegmentApiUpdate:
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_dataset.indexing_technique = "economy"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = mock_segment
@@ -1280,7 +1280,7 @@ class TestDatasetSegmentApiUpdate:
"""Test 404 when dataset not found for update."""
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
@@ -1321,7 +1321,7 @@ class TestDatasetSegmentApiUpdate:
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_dataset.indexing_technique = "economy"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
@@ -1370,7 +1370,7 @@ class TestDatasetSegmentApiGetSingle:
):
"""Test successful single segment retrieval."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc_svc.get_document.return_value = mock_doc
@@ -1405,7 +1405,7 @@ class TestDatasetSegmentApiGetSingle:
):
"""Test 404 when dataset not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
@@ -1436,7 +1436,7 @@ class TestDatasetSegmentApiGetSingle:
):
"""Test 404 when document not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = None
@@ -1471,7 +1471,7 @@ class TestDatasetSegmentApiGetSingle:
):
"""Test 404 when segment not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
@@ -1515,7 +1515,7 @@ class TestChildChunkApiGet:
):
"""Test successful child chunk list retrieval."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = Mock()
@@ -1554,7 +1554,7 @@ class TestChildChunkApiGet:
):
"""Test 404 when dataset not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
@@ -1583,7 +1583,7 @@ class TestChildChunkApiGet:
):
"""Test 404 when document not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = None
with app.test_request_context(
@@ -1615,7 +1615,7 @@ class TestChildChunkApiGet:
):
"""Test 404 when segment not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
@@ -1676,7 +1676,7 @@ class TestChildChunkApiPost:
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_dataset.indexing_technique = "economy"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = Mock()
mock_child = Mock()
@@ -1717,7 +1717,7 @@ class TestChildChunkApiPost:
"""Test 404 when dataset not found."""
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
@@ -1755,7 +1755,7 @@ class TestChildChunkApiPost:
"""Test 404 when segment not found."""
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
@@ -1808,7 +1808,7 @@ class TestDatasetChildChunkApiDelete:
):
"""Test successful child chunk deletion."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc_svc.get_document.return_value = mock_doc
@@ -1858,7 +1858,7 @@ class TestDatasetChildChunkApiDelete:
):
"""Test 404 when child chunk not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
segment_id = str(uuid.uuid4())
@@ -1899,7 +1899,7 @@ class TestDatasetChildChunkApiDelete:
):
"""Test 404 when segment does not belong to the document."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
segment_id = str(uuid.uuid4())
@@ -1939,7 +1939,7 @@ class TestDatasetChildChunkApiDelete:
):
"""Test 404 when child chunk does not belong to the segment."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
segment_id = str(uuid.uuid4())
diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py
index 7f77e61ee4..12d5e7345d 100644
--- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py
+++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py
@@ -717,7 +717,7 @@ class TestDocumentApiDelete:
dataset_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = mock_document
mock_doc_svc.check_archived.return_value = False
@@ -746,7 +746,7 @@ class TestDocumentApiDelete:
document_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = None
@@ -767,7 +767,7 @@ class TestDocumentApiDelete:
dataset_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = mock_document
mock_doc_svc.check_archived.return_value = True
@@ -788,7 +788,7 @@ class TestDocumentApiDelete:
# Arrange
dataset_id = str(uuid.uuid4())
document_id = str(uuid.uuid4())
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context(
@@ -809,7 +809,7 @@ class TestDocumentListApi:
def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset):
"""Test successful document list retrieval."""
# Arrange
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_pagination = Mock()
mock_pagination.items = [Mock(), Mock()]
@@ -838,7 +838,7 @@ class TestDocumentListApi:
def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset):
"""Test 404 when dataset not found."""
# Arrange
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context(
@@ -860,8 +860,6 @@ class TestDocumentIndexingStatusApi:
"""Test successful indexing status retrieval."""
# Arrange
batch_id = "batch_123"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
-
mock_doc = Mock()
mock_doc.id = str(uuid.uuid4())
mock_doc.is_paused = False
@@ -877,8 +875,8 @@ class TestDocumentIndexingStatusApi:
mock_doc_svc.get_batch_documents.return_value = [mock_doc]
- # Mock segment count queries
- mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5
+ # scalar() called 3 times: dataset lookup, completed_segments count, total_segments count
+ mock_db.session.scalar.side_effect = [mock_dataset, 5, 5]
mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"}
# Act
@@ -898,7 +896,7 @@ class TestDocumentIndexingStatusApi:
"""Test 404 when dataset not found."""
# Arrange
batch_id = "batch_123"
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context(
@@ -915,7 +913,7 @@ class TestDocumentIndexingStatusApi:
"""Test 404 when no documents found for batch."""
# Arrange
batch_id = "batch_empty"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_batch_documents.return_value = []
# Act & Assert
@@ -986,7 +984,7 @@ class TestDocumentAddByTextApi:
# Arrange — neutralise billing decorators
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_dataset.indexing_technique = "economy"
mock_current_user.id = str(uuid.uuid4())
@@ -1035,7 +1033,7 @@ class TestDocumentAddByTextApi:
# Arrange — neutralise billing decorators
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context(
@@ -1064,7 +1062,7 @@ class TestDocumentAddByTextApi:
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_dataset.indexing_technique = None
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
# Act & Assert
with app.test_request_context(
@@ -1150,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost:
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_dataset.indexing_technique = "economy"
mock_dataset.latest_process_rule = Mock()
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_current_user.id = "user-1"
mock_upload = Mock()
@@ -1193,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost:
):
"""Test ValueError when dataset not found."""
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
doc_id = str(uuid.uuid4())
with app.test_request_context(
@@ -1232,7 +1230,7 @@ class TestDocumentAddByFileApiPost:
):
"""Test ValueError when dataset not found."""
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
from io import BytesIO
@@ -1263,7 +1261,7 @@ class TestDocumentAddByFileApiPost:
"""Test ValueError when dataset is external."""
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_dataset.provider = "external"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
from io import BytesIO
@@ -1298,7 +1296,7 @@ class TestDocumentAddByFileApiPost:
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "economy"
mock_dataset.chunk_structure = None
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
with app.test_request_context(
f"/datasets/{mock_dataset.id}/document/create_by_file",
@@ -1328,7 +1326,7 @@ class TestDocumentAddByFileApiPost:
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = None
mock_dataset.chunk_structure = None
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
from io import BytesIO
@@ -1366,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost:
):
"""Test ValueError when dataset not found."""
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
- mock_db.session.query.return_value.where.return_value.first.return_value = None
+ mock_db.session.scalar.return_value = None
from io import BytesIO
@@ -1402,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost:
"""Test ValueError when dataset is external."""
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_dataset.provider = "external"
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
from io import BytesIO
@@ -1450,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost:
mock_dataset.chunk_structure = None
mock_dataset.latest_process_rule = Mock()
mock_dataset.created_by_account = Mock()
- mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
+ mock_db.session.scalar.return_value = mock_dataset
mock_current_user.id = "user-1"
mock_upload = Mock()
diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py
index 602488d6ea..542179e2a3 100644
--- a/api/tests/unit_tests/services/dataset_service_test_helpers.py
+++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py
@@ -10,13 +10,13 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest
+from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from werkzeug.exceptions import Forbidden, NotFound
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from enums.cloud_plan import CloudPlan
from models import Account, TenantAccountRole
from models.dataset import (
diff --git a/api/uv.lock b/api/uv.lock
index 2faa0d38c1..fb08594fb3 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -5340,11 +5340,11 @@ wheels = [
[[package]]
name = "pypdf"
-version = "6.9.1"
+version = "6.9.2"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" },
+ { url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" },
]
[[package]]
diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx
index dd2f74f7e5..a16ae9d823 100644
--- a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx
+++ b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx
@@ -3,13 +3,10 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { act, render, waitFor } from '@testing-library/react'
import {
BLUR_COMMAND,
- COMMAND_PRIORITY_EDITOR,
FOCUS_COMMAND,
- KEY_ESCAPE_COMMAND,
} from 'lexical'
import OnBlurBlock from '../on-blur-or-focus-block'
import { CaptureEditorPlugin } from '../test-utils'
-import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block'
const renderOnBlurBlock = (props?: {
onBlur?: () => void
@@ -75,7 +72,7 @@ describe('OnBlurBlock', () => {
expect(onFocus).toHaveBeenCalledTimes(1)
})
- it('should call onBlur and dispatch escape after delay when blur target is not var-search-input', async () => {
+ it('should call onBlur when blur target is not var-search-input', async () => {
const onBlur = vi.fn()
const { getEditor } = renderOnBlurBlock({ onBlur })
@@ -85,14 +82,6 @@ describe('OnBlurBlock', () => {
const editor = getEditor()
expect(editor).not.toBeNull()
- vi.useFakeTimers()
-
- const onEscape = vi.fn(() => true)
- const unregister = editor!.registerCommand(
- KEY_ESCAPE_COMMAND,
- onEscape,
- COMMAND_PRIORITY_EDITOR,
- )
let handled = false
act(() => {
@@ -101,18 +90,9 @@ describe('OnBlurBlock', () => {
expect(handled).toBe(true)
expect(onBlur).toHaveBeenCalledTimes(1)
- expect(onEscape).not.toHaveBeenCalled()
-
- act(() => {
- vi.advanceTimersByTime(200)
- })
-
- expect(onEscape).toHaveBeenCalledTimes(1)
- unregister()
- vi.useRealTimers()
})
- it('should dispatch delayed escape when onBlur callback is not provided', async () => {
+ it('should handle blur when onBlur callback is not provided', async () => {
const { getEditor } = renderOnBlurBlock()
await waitFor(() => {
@@ -121,28 +101,16 @@ describe('OnBlurBlock', () => {
const editor = getEditor()
expect(editor).not.toBeNull()
- vi.useFakeTimers()
-
- const onEscape = vi.fn(() => true)
- const unregister = editor!.registerCommand(
- KEY_ESCAPE_COMMAND,
- onEscape,
- COMMAND_PRIORITY_EDITOR,
- )
+ let handled = false
act(() => {
- editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
- })
- act(() => {
- vi.advanceTimersByTime(200)
+ handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
})
- expect(onEscape).toHaveBeenCalledTimes(1)
- unregister()
- vi.useRealTimers()
+ expect(handled).toBe(true)
})
- it('should skip onBlur and delayed escape when blur target is var-search-input', async () => {
+ it('should skip onBlur when blur target is var-search-input', async () => {
const onBlur = vi.fn()
const { getEditor } = renderOnBlurBlock({ onBlur })
@@ -152,31 +120,17 @@ describe('OnBlurBlock', () => {
const editor = getEditor()
expect(editor).not.toBeNull()
- vi.useFakeTimers()
const target = document.createElement('input')
target.classList.add('var-search-input')
- const onEscape = vi.fn(() => true)
- const unregister = editor!.registerCommand(
- KEY_ESCAPE_COMMAND,
- onEscape,
- COMMAND_PRIORITY_EDITOR,
- )
-
let handled = false
act(() => {
handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(target))
})
- act(() => {
- vi.advanceTimersByTime(200)
- })
expect(handled).toBe(true)
expect(onBlur).not.toHaveBeenCalled()
- expect(onEscape).not.toHaveBeenCalled()
- unregister()
- vi.useRealTimers()
})
it('should handle focus command when onFocus callback is not provided', async () => {
@@ -198,59 +152,6 @@ describe('OnBlurBlock', () => {
})
})
- describe('Clear timeout command', () => {
- it('should clear scheduled escape timeout when clear command is dispatched', async () => {
- const { getEditor } = renderOnBlurBlock({ onBlur: vi.fn() })
-
- await waitFor(() => {
- expect(getEditor()).not.toBeNull()
- })
-
- const editor = getEditor()
- expect(editor).not.toBeNull()
- vi.useFakeTimers()
-
- const onEscape = vi.fn(() => true)
- const unregister = editor!.registerCommand(
- KEY_ESCAPE_COMMAND,
- onEscape,
- COMMAND_PRIORITY_EDITOR,
- )
-
- act(() => {
- editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
- })
- act(() => {
- editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
- })
- act(() => {
- vi.advanceTimersByTime(200)
- })
-
- expect(onEscape).not.toHaveBeenCalled()
- unregister()
- vi.useRealTimers()
- })
-
- it('should handle clear command when no timeout is scheduled', async () => {
- const { getEditor } = renderOnBlurBlock()
-
- await waitFor(() => {
- expect(getEditor()).not.toBeNull()
- })
-
- const editor = getEditor()
- expect(editor).not.toBeNull()
-
- let handled = false
- act(() => {
- handled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
- })
-
- expect(handled).toBe(true)
- })
- })
-
describe('Lifecycle cleanup', () => {
it('should unregister commands when component unmounts', async () => {
const { getEditor, unmount } = renderOnBlurBlock()
@@ -266,16 +167,13 @@ describe('OnBlurBlock', () => {
let blurHandled = true
let focusHandled = true
- let clearHandled = true
act(() => {
blurHandled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
focusHandled = editor!.dispatchCommand(FOCUS_COMMAND, createFocusEvent())
- clearHandled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
})
expect(blurHandled).toBe(false)
expect(focusHandled).toBe(false)
- expect(clearHandled).toBe(false)
})
})
})
diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx
index 8f6a72a7de..4283910c31 100644
--- a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx
+++ b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx
@@ -1,14 +1,13 @@
import type { LexicalEditor } from 'lexical'
import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { act, render, waitFor } from '@testing-library/react'
-import { $getRoot, COMMAND_PRIORITY_EDITOR } from 'lexical'
+import { $getRoot } from 'lexical'
import { CustomTextNode } from '../custom-text/node'
import { CaptureEditorPlugin } from '../test-utils'
import UpdateBlock, {
PROMPT_EDITOR_INSERT_QUICKLY,
PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER,
} from '../update-block'
-import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block'
const { mockUseEventEmitterContextContext } = vi.hoisted(() => ({
mockUseEventEmitterContextContext: vi.fn(),
@@ -157,7 +156,7 @@ describe('UpdateBlock', () => {
})
describe('Quick insert event', () => {
- it('should insert slash and dispatch clear command when quick insert event matches instance id', async () => {
+ it('should insert slash when quick insert event matches instance id', async () => {
const { emit, getEditor } = setup({ instanceId: 'instance-1' })
await waitFor(() => {
@@ -168,13 +167,6 @@ describe('UpdateBlock', () => {
selectRootEnd(editor!)
- const clearCommandHandler = vi.fn(() => true)
- const unregister = editor!.registerCommand(
- CLEAR_HIDE_MENU_TIMEOUT,
- clearCommandHandler,
- COMMAND_PRIORITY_EDITOR,
- )
-
emit({
type: PROMPT_EDITOR_INSERT_QUICKLY,
instanceId: 'instance-1',
@@ -183,9 +175,6 @@ describe('UpdateBlock', () => {
await waitFor(() => {
expect(readEditorText(editor!)).toBe('/')
})
- expect(clearCommandHandler).toHaveBeenCalledTimes(1)
-
- unregister()
})
it('should ignore quick insert event when instance id does not match', async () => {
diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx
index 6cc6c3a67f..51b14b76c8 100644
--- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx
+++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx
@@ -23,6 +23,8 @@ import {
$createTextNode,
$getRoot,
$setSelection,
+ BLUR_COMMAND,
+ FOCUS_COMMAND,
KEY_ESCAPE_COMMAND,
} from 'lexical'
import * as React from 'react'
@@ -631,4 +633,180 @@ describe('ComponentPicker (component-picker-block/index.tsx)', () => {
// With a single option group, the only divider should be the workflow-var/options separator.
expect(document.querySelectorAll('.bg-divider-subtle')).toHaveLength(1)
})
+
+ describe('blur/focus menu visibility', () => {
+ it('hides the menu after a 200ms delay when blur command is dispatched', async () => {
+ const captures: Captures = { editor: null, eventEmitter: null }
+
+ render((
+
+ ))
+
+ const editor = await waitForEditor(captures)
+ await setEditorText(editor, '{', true)
+ expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useFakeTimers()
+
+ act(() => {
+ editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
+ })
+
+ expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ act(() => {
+ vi.advanceTimersByTime(200)
+ })
+
+ expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument()
+
+ vi.useRealTimers()
+ })
+
+ it('restores menu visibility when focus command is dispatched after blur hides it', async () => {
+ const captures: Captures = { editor: null, eventEmitter: null }
+
+ render((
+
+ ))
+
+ const editor = await waitForEditor(captures)
+ await setEditorText(editor, '{', true)
+ expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useFakeTimers()
+
+ act(() => {
+ editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
+ })
+ act(() => {
+ vi.advanceTimersByTime(200)
+ })
+
+ expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument()
+
+ act(() => {
+ editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus'))
+ })
+
+ vi.useRealTimers()
+
+ await setEditorText(editor, '{', true)
+ await waitFor(() => {
+ expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+ })
+ })
+
+ it('cancels the blur timer when focus arrives before the 200ms timeout', async () => {
+ const captures: Captures = { editor: null, eventEmitter: null }
+
+ render((
+
+ ))
+
+ const editor = await waitForEditor(captures)
+ await setEditorText(editor, '{', true)
+ expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useFakeTimers()
+
+ act(() => {
+ editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
+ })
+
+ act(() => {
+ editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus'))
+ })
+
+ act(() => {
+ vi.advanceTimersByTime(200)
+ })
+
+ expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useRealTimers()
+ })
+
+ it('cancels a pending blur timer when a subsequent blur targets var-search-input', async () => {
+ const captures: Captures = { editor: null, eventEmitter: null }
+
+ render((
+
+ ))
+
+ const editor = await waitForEditor(captures)
+ await setEditorText(editor, '{', true)
+ expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useFakeTimers()
+
+ act(() => {
+ editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
+ })
+
+ const varInput = document.createElement('input')
+ varInput.classList.add('var-search-input')
+
+ act(() => {
+ editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: varInput }))
+ })
+
+ act(() => {
+ vi.advanceTimersByTime(200)
+ })
+
+ expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useRealTimers()
+ })
+
+ it('does not hide the menu when blur target is var-search-input', async () => {
+ const captures: Captures = { editor: null, eventEmitter: null }
+
+ render((
+
+ ))
+
+ const editor = await waitForEditor(captures)
+ await setEditorText(editor, '{', true)
+ expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useFakeTimers()
+
+ const target = document.createElement('input')
+ target.classList.add('var-search-input')
+
+ act(() => {
+ editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: target }))
+ })
+
+ act(() => {
+ vi.advanceTimersByTime(200)
+ })
+
+ expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
+
+ vi.useRealTimers()
+ })
+ })
})
diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx
index 8001a2755b..bebc1b59af 100644
--- a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx
+++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx
@@ -21,11 +21,19 @@ import {
} from '@floating-ui/react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin'
-import { KEY_ESCAPE_COMMAND } from 'lexical'
+import { mergeRegister } from '@lexical/utils'
+import {
+ BLUR_COMMAND,
+ COMMAND_PRIORITY_EDITOR,
+ FOCUS_COMMAND,
+ KEY_ESCAPE_COMMAND,
+} from 'lexical'
import {
Fragment,
memo,
useCallback,
+ useEffect,
+ useRef,
useState,
} from 'react'
import ReactDOM from 'react-dom'
@@ -87,6 +95,46 @@ const ComponentPicker = ({
})
const [queryString, setQueryString] = useState(null)
+ const [blurHidden, setBlurHidden] = useState(false)
+ const blurTimerRef = useRef | null>(null)
+
+ const clearBlurTimer = useCallback(() => {
+ if (blurTimerRef.current) {
+ clearTimeout(blurTimerRef.current)
+ blurTimerRef.current = null
+ }
+ }, [])
+
+ useEffect(() => {
+ const unregister = mergeRegister(
+ editor.registerCommand(
+ BLUR_COMMAND,
+ (event) => {
+ clearBlurTimer()
+ const target = event?.relatedTarget as HTMLElement
+ if (!target?.classList?.contains('var-search-input'))
+ blurTimerRef.current = setTimeout(() => setBlurHidden(true), 200)
+ return false
+ },
+ COMMAND_PRIORITY_EDITOR,
+ ),
+ editor.registerCommand(
+ FOCUS_COMMAND,
+ () => {
+ clearBlurTimer()
+ setBlurHidden(false)
+ return false
+ },
+ COMMAND_PRIORITY_EDITOR,
+ ),
+ )
+
+ return () => {
+ if (blurTimerRef.current)
+ clearTimeout(blurTimerRef.current)
+ unregister()
+ }
+ }, [editor, clearBlurTimer])
eventEmitter?.useSubscription((v: any) => {
if (v.type === INSERT_VARIABLE_VALUE_BLOCK_COMMAND)
@@ -159,6 +207,8 @@ const ComponentPicker = ({
anchorElementRef,
{ options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
) => {
+ if (blurHidden)
+ return null
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
return null
@@ -240,7 +290,7 @@ const ComponentPicker = ({
}
>
)
- }, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
+ }, [blurHidden, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
return (
void
@@ -20,35 +18,13 @@ const OnBlurBlock: FC = ({
}) => {
const [editor] = useLexicalComposerContext()
- const ref = useRef | null>(null)
-
useEffect(() => {
- const clearHideMenuTimeout = () => {
- if (ref.current) {
- clearTimeout(ref.current)
- ref.current = null
- }
- }
-
- const unregister = mergeRegister(
- editor.registerCommand(
- CLEAR_HIDE_MENU_TIMEOUT,
- () => {
- clearHideMenuTimeout()
- return true
- },
- COMMAND_PRIORITY_EDITOR,
- ),
+ return mergeRegister(
editor.registerCommand(
BLUR_COMMAND,
(event) => {
- // Check if the clicked target element is var-search-input
const target = event?.relatedTarget as HTMLElement
if (!target?.classList?.contains('var-search-input')) {
- clearHideMenuTimeout()
- ref.current = setTimeout(() => {
- editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' }))
- }, 200)
if (onBlur)
onBlur()
}
@@ -66,11 +42,6 @@ const OnBlurBlock: FC = ({
COMMAND_PRIORITY_EDITOR,
),
)
-
- return () => {
- clearHideMenuTimeout()
- unregister()
- }
}, [editor, onBlur, onFocus])
return null
diff --git a/web/app/components/base/prompt-editor/plugins/update-block.tsx b/web/app/components/base/prompt-editor/plugins/update-block.tsx
index bf89a259af..2d83573b1f 100644
--- a/web/app/components/base/prompt-editor/plugins/update-block.tsx
+++ b/web/app/components/base/prompt-editor/plugins/update-block.tsx
@@ -3,7 +3,6 @@ import { $insertNodes } from 'lexical'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { textToEditorState } from '../utils'
import { CustomTextNode } from './custom-text/node'
-import { CLEAR_HIDE_MENU_TIMEOUT } from './workflow-variable-block'
export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER'
export const PROMPT_EDITOR_INSERT_QUICKLY = 'PROMPT_EDITOR_INSERT_QUICKLY'
@@ -30,8 +29,6 @@ const UpdateBlock = ({
editor.update(() => {
const textNode = new CustomTextNode('/')
$insertNodes([textNode])
-
- editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
})
}
})
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx
index ca4973b830..1591dc44f9 100644
--- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx
+++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx
@@ -9,7 +9,6 @@ import { $insertNodes, COMMAND_PRIORITY_EDITOR } from 'lexical'
import { Type } from '@/app/components/workflow/nodes/llm/types'
import { BlockEnum } from '@/app/components/workflow/types'
import {
- CLEAR_HIDE_MENU_TIMEOUT,
DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND,
INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND,
UPDATE_WORKFLOW_NODES_MAP,
@@ -134,7 +133,6 @@ describe('WorkflowVariableBlock', () => {
const insertHandler = mockRegisterCommand.mock.calls[0][1] as (variables: string[]) => boolean
const result = insertHandler(['node-1', 'answer'])
- expect(mockDispatchCommand).toHaveBeenCalledWith(CLEAR_HIDE_MENU_TIMEOUT, undefined)
expect($createWorkflowVariableBlockNode).toHaveBeenCalledWith(
['node-1', 'answer'],
workflowNodesMap,
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx
index 76b2795803..c8cac64d19 100644
--- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx
+++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx
@@ -18,7 +18,6 @@ import {
export const INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND')
export const DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND')
-export const CLEAR_HIDE_MENU_TIMEOUT = createCommand('CLEAR_HIDE_MENU_TIMEOUT')
export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP')
export type WorkflowVariableBlockProps = {
@@ -49,7 +48,6 @@ const WorkflowVariableBlock = memo(({
editor.registerCommand(
INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND,
(variables: string[]) => {
- editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType)
$insertNodes([workflowVariableBlockNode])
diff --git a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx
index d8ab1cafdd..38ce5cca4e 100644
--- a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx
+++ b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx
@@ -142,7 +142,7 @@ const ApiKeyModal = ({
onExtraButtonClick={onRemove}
disabled={disabled || isLoading || doingAction}
clickOutsideNotClose={true}
- wrapperClassName="!z-[101]"
+ wrapperClassName="!z-[1002]"
>
{pluginPayload.detail && (
diff --git a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx
index 28989da77c..a69cc1538b 100644
--- a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx
+++ b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx
@@ -157,7 +157,7 @@ const OAuthClientSettings = ({
)
}
containerClassName="pt-0"
- wrapperClassName="!z-[101]"
+ wrapperClassName="!z-[1002]"
clickOutsideNotClose={true}
>
{pluginPayload.detail && (
diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json
index e28d915e66..3d311dbf54 100644
--- a/web/eslint-suppressions.json
+++ b/web/eslint-suppressions.json
@@ -2768,7 +2768,7 @@
},
"app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx": {
"react-refresh/only-export-components": {
- "count": 4
+ "count": 3
}
},
"app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx": {