From 1a011dc14af6fcfecfbc2c9524cce48e24309902 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 11 May 2026 12:37:26 +0900 Subject: [PATCH] refactor: port DatasetProcessRule (#31004) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/indexing_runner.py | 9 +++-- .../processor/parent_child_index_processor.py | 3 +- api/models/dataset.py | 22 ++++++----- api/services/dataset_service.py | 12 +++--- .../test_dataset_service_retrieval.py | 15 ++++---- .../test_disable_segments_from_index_task.py | 37 ++++++++++--------- .../unit_tests/models/test_dataset_models.py | 4 +- 7 files changed, 54 insertions(+), 48 deletions(-) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b6e33396d1..537b14388e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -324,9 +324,10 @@ class IndexingRunner: # one extract_setting is one source document for extract_setting in extract_settings: # extract - processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) - ) + processing_rule = { + "mode": tmp_processing_rule["mode"], + "rules": tmp_processing_rule.get("rules"), + } # Extract document content text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) # Cleaning and segmentation @@ -334,7 +335,7 @@ class IndexingRunner: text_docs, current_user=None, embedding_model_instance=embedding_model_instance, - process_rule=processing_rule.to_dict(), + process_rule=processing_rule, tenant_id=tenant_id, doc_language=doc_language, preview=True, diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index ba277d5018..a26a900512 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -29,6 +29,7 @@ from libs import helper from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import ProcessRuleMode from services.account_service import AccountService from services.summary_index_service import SummaryIndexService @@ -325,7 +326,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # update document parent mode dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode="hierarchical", + mode=ProcessRuleMode.HIERARCHICAL, rules=json.dumps( { "parent_mode": parent_childs.parent_mode, diff --git a/api/models/dataset.py b/api/models/dataset.py index ed7727e0f1..f823e0aa10 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -11,7 +11,7 @@ import time from collections.abc import Sequence from datetime import datetime from json import JSONDecodeError -from typing import Any, TypedDict, cast +from typing import Any, ClassVar, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -441,23 +441,27 @@ class Dataset(Base): return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node" -class DatasetProcessRule(Base): # bug +class DatasetProcessRule(TypeBase): __tablename__ = "dataset_process_rules" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), 
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) - dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")) - rules = mapped_column(LongText, nullable=True) - created_by = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + mode: Mapped[ProcessRuleMode] = mapped_column( + EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'") + ) + rules: Mapped[str | None] = mapped_column(LongText, nullable=True) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES: AutomaticRulesConfig = { + AUTOMATIC_RULES: ClassVar[AutomaticRulesConfig] = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index eef38f1ce2..383474f4f6 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -108,7 +108,7 @@ logger = logging.getLogger(__name__) class ProcessRulesDict(TypedDict): - mode: str + mode: ProcessRuleMode rules: dict[str, Any] @@ -204,7 +204,7 @@ class DatasetService: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict or {} else: - mode = str(DocumentService.DEFAULT_RULES["mode"]) + mode = ProcessRuleMode(DocumentService.DEFAULT_RULES["mode"]) rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {}) return {"mode": mode, "rules": rules} @@ -1984,7 +1984,7 @@ class DocumentService: if process_rule.rules: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode(process_rule.mode), rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) @@ -1995,7 +1995,7 @@ class DocumentService: elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode.AUTOMATIC, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -2572,14 +2572,14 @@ class DocumentService: if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode(process_rule.mode), rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode.AUTOMATIC, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py 
b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 2f90d16176..0c610311bb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -16,6 +16,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType +from models import AccountStatus, CreatorUserRole, TenantStatus from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -25,7 +26,7 @@ from models.dataset import ( DatasetProcessRule, DatasetQuery, ) -from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode +from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode, TagType from models.model import Tag, TagBinding from services.dataset_service import DatasetService, DocumentService @@ -42,11 +43,11 @@ class DatasetRetrievalTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) tenant = Tenant( name=f"tenant-{uuid4()}", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add_all([account, tenant]) db_session_with_containers.flush() @@ -72,7 +73,7 @@ class DatasetRetrievalTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.flush() @@ -130,7 +131,7 @@ class DatasetRetrievalTestDataFactory: @staticmethod def create_process_rule( - db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict + db_session_with_containers: Session, dataset_id: str, created_by: str, mode: ProcessRuleMode, rules: dict ) -> DatasetProcessRule: """Create a dataset process rule.""" process_rule = DatasetProcessRule( @@ -153,7 +154,7 @@ class DatasetRetrievalTestDataFactory: content=content, source=DatasetQuerySource.APP, source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=created_by, ) db_session_with_containers.add(dataset_query) @@ -176,7 +177,7 @@ class DatasetRetrievalTestDataFactory: """Create a knowledge tag and bind it to the target dataset.""" tag = Tag( tenant_id=tenant_id, - type="knowledge", + type=TagType.KNOWLEDGE, name=f"tag-{uuid4()}", created_by=created_by, ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 6bfb1e1f1e..6a95bfc425 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -7,6 +7,7 @@ The task is responsible for removing document segments from the search index whe """ from unittest.mock import MagicMock, patch +from uuid import uuid4 from faker import Faker from sqlalchemy import select @@ -82,7 +83,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None): + def _create_test_dataset(self, db_session_with_containers: Session, account: Account, fake: Faker | None = None): """ Helper method to create a 
test dataset with realistic data. @@ -117,7 +118,7 @@ class TestDisableSegmentsFromIndexTask: return dataset def _create_test_document( - self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None + self, db_session_with_containers: Session, dataset: Dataset, account: Account, fake: Faker | None = None ): """ Helper method to create a test document with realistic data. @@ -164,7 +165,7 @@ class TestDisableSegmentsFromIndexTask: return document def _create_test_segments( - self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None + self, db_session_with_containers: Session, document, dataset: Dataset, account: Account, count=3, fake=None ): """ Helper method to create test document segments with realistic data. @@ -217,7 +218,9 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None): + def _create_dataset_process_rule( + self, db_session_with_containers: Session, dataset: Dataset, fake: Faker | None = None + ): """ Helper method to create a dataset process rule. @@ -230,21 +233,19 @@ class TestDisableSegmentsFromIndexTask: DatasetProcessRule: Created process rule instance """ fake = fake or Faker() - process_rule = DatasetProcessRule() - process_rule.id = fake.uuid4() - process_rule.tenant_id = dataset.tenant_id - process_rule.dataset_id = dataset.id - process_rule.mode = ProcessRuleMode.AUTOMATIC - process_rule.rules = ( - "{" - '"mode": "automatic", ' - '"rules": {' - '"pre_processing_rules": [], "segmentation": ' - '{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}' - "}" + process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=ProcessRuleMode.AUTOMATIC, + rules=( + "{" + '"mode": "automatic", ' + '"rules": {' + '"pre_processing_rules": [], "segmentation": ' + '{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}' + "}" + ), + created_by=str(uuid4()), ) - process_rule.created_by = dataset.created_by - process_rule.updated_by = dataset.updated_by db_session_with_containers.add(process_rule) db_session_with_containers.commit() diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 3f14ebe8bf..f4ccfb4191 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -847,9 +847,7 @@ class TestDatasetProcessRule: # Act process_rule = DatasetProcessRule( - dataset_id=dataset_id, - mode=ProcessRuleMode.AUTOMATIC, - created_by=created_by, + dataset_id=dataset_id, mode=ProcessRuleMode.AUTOMATIC, created_by=created_by, rules=None ) # Assert
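
A note on the pattern this patch adopts: DatasetProcessRule moves from the imperative `Base` style to a dataclass-mapped `TypeBase`, which is why the patch adds `init=False` to server-generated columns, a `default_factory` for the UUID primary key, and `ClassVar` on `AUTOMATIC_RULES`. Below is a minimal, self-contained sketch of that convention using stock SQLAlchemy 2.0; it is an illustration under stated assumptions, not the project's actual classes — `TypeBase`, `StringUUID`, `EnumText`, and `LongText` are Dify-specific, so the sketch substitutes a local `MappedAsDataclass` base and plain `String`/`Text`/`Enum` columns.

# A minimal sketch of the dataclass-mapped model convention this patch moves
# DatasetProcessRule to. Assumptions: stock SQLAlchemy 2.0 and Python 3.11+;
# the local Base, String/Text/Enum columns, and the ProcessRule name stand in
# for Dify's TypeBase, StringUUID/EnumText/LongText, and DatasetProcessRule.
import enum
from datetime import datetime
from typing import ClassVar
from uuid import uuid4

from sqlalchemy import DateTime, Enum, String, Text, func
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class Base(MappedAsDataclass, DeclarativeBase):
    """Dataclass-mapped base: __init__ is generated from the mapped fields."""


class ProcessRuleMode(enum.StrEnum):
    AUTOMATIC = "automatic"
    CUSTOM = "custom"
    HIERARCHICAL = "hierarchical"


class ProcessRule(Base):
    __tablename__ = "process_rules"

    # init=False + default_factory: excluded from __init__, generated per row,
    # mirroring the patch's `id` column.
    id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default_factory=lambda: str(uuid4()), init=False
    )
    dataset_id: Mapped[str] = mapped_column(String(36))
    # values_callable persists the lowercase enum values, matching the
    # server_default of 'automatic'.
    mode: Mapped[ProcessRuleMode] = mapped_column(
        Enum(
            ProcessRuleMode,
            native_enum=False,
            length=255,
            values_callable=lambda e: [m.value for m in e],
        ),
        server_default="automatic",
    )
    rules: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_by: Mapped[str] = mapped_column(String(36))
    # init=False + server_default: the database fills this in on INSERT.
    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), init=False
    )

    # ClassVar keeps this a plain class constant instead of a dataclass field
    # (and thus out of __init__ and out of the mapped columns).
    MODES: ClassVar[list[str]] = ["automatic", "custom", "hierarchical"]


# Construction now happens through generated keyword arguments, as in the
# refactored tests; `id` and `created_at` cannot be passed explicitly.
rule = ProcessRule(
    dataset_id=str(uuid4()),
    mode=ProcessRuleMode.AUTOMATIC,
    rules=None,
    created_by=str(uuid4()),
)

With `MappedAsDataclass`, every annotated class attribute becomes both a mapped column and an `__init__` parameter; that is why the patch marks `AUTOMATIC_RULES` as `ClassVar` (keeping it a plain class constant) and why the refactored test in test_disable_segments_from_index_task.py constructs the rule through keyword arguments instead of assigning attributes one by one after `DatasetProcessRule()`.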