diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index ed264878d3..242da520c1 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -139,8 +139,10 @@ class Jieba(BaseKeyword): "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table - keyword_data_source_type = dataset_keyword_table.data_source_type + keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file" if keyword_data_source_type == "database": + if dataset_keyword_table is None: + return dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict) db.session.commit() else: diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 84f35c25f8..1ca6303af6 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,4 +1,5 @@ import re +from collections.abc import Callable from operator import itemgetter from typing import cast @@ -80,12 +81,14 @@ class JiebaKeywordTableHandler: def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs): # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable. - top_k = kwargs.pop("topK", top_k) + top_k = cast(int | None, kwargs.pop("topK", top_k)) + if top_k is None: + top_k = 20 cut = getattr(jieba, "cut", None) if self._lcut: tokens = self._lcut(sentence) elif callable(cut): - tokens = list(cut(sentence)) + tokens = list(cast(Callable[[str], list[str]], cut)(sentence)) else: tokens = re.findall(r"\w+", sentence) @@ -108,7 +111,7 @@ class JiebaKeywordTableHandler: sentence=text, topK=max_keywords_per_chunk, ) - # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default. + # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. keywords = cast(list[str], keywords) return set(self._expand_tokens_with_subtokens(set(keywords))) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 7e71d67ec0..2997710daf 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -158,7 +158,7 @@ class RetrievalService: ) if futures: - for future in concurrent.futures.as_completed(futures, timeout=3600): + for _ in concurrent.futures.as_completed(futures, timeout=3600): if exceptions: for f in futures: f.cancel() diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index fbd2a6db93..b679edab36 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -94,6 +94,7 @@ class ExtractProcessor: cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE: + upload_file = extract_setting.upload_file with tempfile.TemporaryDirectory() as temp_dir: upload_file = extract_setting.upload_file if not file_path: @@ -104,6 +105,7 @@ class ExtractProcessor: storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() + assert upload_file is not None, "upload_file is required" etl_type = dify_config.ETL_TYPE extractor: BaseExtractor | None = None if etl_type == "Unstructured": diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index e617a9660e..426d1b67dc 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter: SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage(content=query), ] - result: LLMResult = model_instance.invoke_llm( + result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType] prompt_messages=prompt_messages, tools=dataset_tools, stream=False, diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 66b375dad1..52c9a02f97 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -4,7 +4,7 @@ from __future__ import annotations import codecs import re -from collections.abc import Collection +from collections.abc import Set as AbstractSet from typing import Any, Literal from core.model_manager import ModelInstance @@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder[T: EnhanceRecursiveCharacterTextSplitter]( cls: type[T], embedding_model_instance: ModelInstance | None, - allowed_special: Literal["all"] | set[str] = set(), - disallowed_special: Literal["all"] | Collection[str] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = frozenset(), + disallowed_special: Literal["all"] | AbstractSet[str] = "all", **kwargs: Any, ) -> T: def _token_encoder(texts: list[str]) -> list[int]: @@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return [len(text) for text in texts] + _ = _token_encoder # kept for future token-length wiring return cls(length_function=_character_encoder, **kwargs) diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 7f2117e2dd..a8d9013fbc 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -4,7 +4,8 @@ import copy import logging import re from abc import ABC, abstractmethod -from collections.abc import Callable, Collection, Iterable, Sequence, Set +from collections.abc import Callable, Iterable, Sequence +from collections.abc import Set as AbstractSet from dataclasses import dataclass from typing import Any, Literal @@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter): self, encoding_name: str = "gpt2", model_name: str | None = None, - allowed_special: Literal["all"] | Set[str] = set(), - disallowed_special: Literal["all"] | Collection[str] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = frozenset(), + disallowed_special: Literal["all"] | AbstractSet[str] = "all", **kwargs: Any, ): """Create a new TextSplitter.""" @@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter): else: enc = tiktoken.get_encoding(encoding_name) self._tokenizer = enc - self._allowed_special = allowed_special - self._disallowed_special = disallowed_special + self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special + self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special def split_text(self, text: str) -> list[str]: def _encode(_text: str) -> list[int]: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index ca84b2a3d8..2e5987dd28 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,10 +1,10 @@ import json import logging import time -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from core.app.app_config.entities import ModelConfig -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -36,6 +36,10 @@ default_retrieval_model = { } +class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False): + metadata_filtering_conditions: dict[str, Any] + + class HitTestingService: @classmethod def retrieve( @@ -51,17 +55,18 @@ class HitTestingService: start = time.perf_counter() # get retrieval model , if the model is not setting , using default - if not retrieval_model: - retrieval_model = dataset.retrieval_model or default_retrieval_model - assert isinstance(retrieval_model, dict) + resolved_retrieval_model = cast( + HitTestingRetrievalModelDict, + retrieval_model or dataset.retrieval_model or default_retrieval_model, + ) document_ids_filter = None - metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions and query: + metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {}) + if metadata_filtering_conditions_raw and query: dataset_retrieval = DatasetRetrieval() from core.rag.entities import MetadataFilteringCondition - metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw) metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], @@ -78,19 +83,21 @@ class HitTestingService: if metadata_condition and not document_ids_filter: return cls.compact_retrieve_response(query, []) all_documents = RetrievalService.retrieve( - retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), + retrieval_method=RetrievalMethod( + resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH) + ), dataset_id=dataset.id, query=query, attachment_ids=attachment_ids, - top_k=retrieval_model.get("top_k", 4), - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] + top_k=resolved_retrieval_model.get("top_k", 4), + score_threshold=resolved_retrieval_model.get("score_threshold", 0.0) + if resolved_retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] + reranking_model=resolved_retrieval_model.get("reranking_model", None) + if resolved_retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model", + weights=resolved_retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, )