refactor(api): tighten core rag typing batch 1 (#35210)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
@@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
             "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
         }
         dataset_keyword_table = self.dataset.dataset_keyword_table
-        keyword_data_source_type = dataset_keyword_table.data_source_type
+        keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
         if keyword_data_source_type == "database":
+            if dataset_keyword_table is None:
+                return
             dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
             db.session.commit()
         else:
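A minimal sketch of the narrowing pattern this hunk relies on: when a value typed as `T | None` is only required in one branch, an explicit `if value is None: return` guard lets the type checker prove the later attribute access is safe, while the `else "file"` fallback keeps the old behaviour when the record is absent. The `KeywordRecord` class and `save_table` function below are hypothetical stand-ins, not the Dify models.

from dataclasses import dataclass


@dataclass
class KeywordRecord:
    data_source_type: str
    keyword_table: str = ""


def save_table(record: KeywordRecord | None, serialized: str) -> None:
    # Fall back to a default source type when the record is missing (mirrors `else "file"`).
    source_type = record.data_source_type if record else "file"
    if source_type == "database":
        # Without this guard, a type checker flags `record.keyword_table` as a possible None access.
        if record is None:
            return
        record.keyword_table = serialized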
@@ -1,4 +1,5 @@
 import re
+from collections.abc import Callable
 from operator import itemgetter
 from typing import cast
@@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
     def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
         # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
-        top_k = kwargs.pop("topK", top_k)
+        top_k = cast(int | None, kwargs.pop("topK", top_k))
         if top_k is None:
             top_k = 20
         cut = getattr(jieba, "cut", None)
         if self._lcut:
             tokens = self._lcut(sentence)
         elif callable(cut):
-            tokens = list(cut(sentence))
+            tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
         else:
             tokens = re.findall(r"\w+", sentence)
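A rough illustration of what the new `cast` calls buy: `getattr(jieba, "cut", None)` and `kwargs.pop(...)` are untyped, so casting them pins the types the rest of the function relies on without changing runtime behaviour. The snippet below is a standalone sketch of the same fallback chain (lcut, then cut, then a regex split); `module` stands in for the untyped `jieba` import.

import re
from collections.abc import Callable
from typing import Any, cast


def tokenize(sentence: str, lcut: Callable[[str], list[str]] | None, module: Any) -> list[str]:
    cut = getattr(module, "cut", None)
    if lcut:
        return lcut(sentence)
    if callable(cut):
        # cast() only informs the type checker; at runtime the callable is used as-is.
        return list(cast(Callable[[str], list[str]], cut)(sentence))
    # Last resort: plain word extraction when no tokenizer is available.
    return re.findall(r"\w+", sentence)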
@@ -108,7 +111,7 @@ class JiebaKeywordTableHandler:
             sentence=text,
             topK=max_keywords_per_chunk,
         )
-        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
         keywords = cast(list[str], keywords)

         return set(self._expand_tokens_with_subtokens(set(keywords)))

@@ -158,7 +158,7 @@ class RetrievalService:
            )

        if futures:
-            for future in concurrent.futures.as_completed(futures, timeout=3600):
+            for _ in concurrent.futures.as_completed(futures, timeout=3600):
                if exceptions:
                    for f in futures:
                        f.cancel()
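Renaming the loop variable to `_` just signals that only the completion events matter, not the future objects themselves. A self-contained sketch of the same drain-and-cancel pattern, assuming (as the surrounding code suggests) that the worker functions record their own failures in a shared `exceptions` list:

import concurrent.futures
from collections.abc import Callable


def run_all(tasks: list[Callable[[], None]], timeout: float = 3600.0) -> list[Exception]:
    exceptions: list[Exception] = []

    def guarded(task: Callable[[], None]) -> None:
        try:
            task()
        except Exception as e:
            exceptions.append(e)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(guarded, task) for task in tasks]
        if futures:
            # Only completion matters here, so the future yielded by as_completed is discarded.
            for _ in concurrent.futures.as_completed(futures, timeout=timeout):
                if exceptions:
                    for f in futures:
                        f.cancel()
    return exceptions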
@@ -94,6 +94,7 @@ class ExtractProcessor:
         cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
     ) -> list[Document]:
         if extract_setting.datasource_type == DatasourceType.FILE:
+            upload_file = extract_setting.upload_file
             with tempfile.TemporaryDirectory() as temp_dir:
                 upload_file = extract_setting.upload_file
                 if not file_path:
@@ -104,6 +105,7 @@ class ExtractProcessor:
                     storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
+                assert upload_file is not None, "upload_file is required"
                 etl_type = dify_config.ETL_TYPE
                 extractor: BaseExtractor | None = None
                 if etl_type == "Unstructured":
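The added assert is a standard way to narrow an `X | None` value for the type checker when the surrounding control flow already guarantees it is set. A minimal, self-contained sketch; the `UploadFile` class here is a hypothetical stand-in for the real model:

class UploadFile:
    def __init__(self, key: str) -> None:
        self.key = key


def resolve_path(upload_file: UploadFile | None, file_path: str | None) -> str:
    if not file_path:
        # Narrow the optional right before it is needed; fails loudly if the contract is broken.
        assert upload_file is not None, "upload_file is required"
        return upload_file.key
    return file_path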
@@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter:
             SystemPromptMessage(content="You are a helpful AI assistant."),
             UserPromptMessage(content=query),
         ]
-        result: LLMResult = model_instance.invoke_llm(
+        result: LLMResult = model_instance.invoke_llm(  # pyright: ignore[reportCallIssue, reportArgumentType]
             prompt_messages=prompt_messages,
             tools=dataset_tools,
             stream=False,
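The added comment is a scoped pyright suppression: only the listed diagnostics are ignored, and only on that line, rather than a blanket `# type: ignore`. A toy example of the same pattern (the function and call are made up):

def takes_int(value: int) -> int:
    return value * 2


untyped: object = "3"
# Only reportArgumentType is silenced, and only for this call site.
result = takes_int(untyped)  # pyright: ignore[reportArgumentType]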
@@ -4,7 +4,7 @@ from __future__ import annotations

 import codecs
 import re
-from collections.abc import Collection
+from collections.abc import Set as AbstractSet
 from typing import Any, Literal

 from core.model_manager import ModelInstance
@@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
         cls: type[T],
         embedding_model_instance: ModelInstance | None,
-        allowed_special: Literal["all"] | set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ) -> T:
         def _token_encoder(texts: list[str]) -> list[int]:
@@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):

             return [len(text) for text in texts]

+        _ = _token_encoder  # kept for future token-length wiring
         return cls(length_function=_character_encoder, **kwargs)
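Two things change in these splitter hunks: the parameter type moves from the concrete `set[str]` / `Collection[str]` to the abstract `collections.abc.Set` (aliased `AbstractSet`), and the default moves from the mutable `set()` to `frozenset()`. `AbstractSet` accepts `set`, `frozenset`, and dict key views alike, while the immutable default avoids the shared-mutable-default pitfall. A small sketch under those assumptions; the `encode` function is illustrative, not the splitter's API:

from collections.abc import Set as AbstractSet
from typing import Literal


def encode(
    text: str,
    allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
) -> list[str]:
    # The frozenset() default is evaluated once but can never be mutated across calls.
    if allowed_special == "all" or text in allowed_special:
        return [text]
    return text.split()

Callers can still pass a plain set, e.g. encode("<|endoftext|>", allowed_special={"<|endoftext|>"}), because a built-in set satisfies the abstract Set protocol.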
@@ -4,7 +4,8 @@ import copy
 import logging
 import re
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Collection, Iterable, Sequence, Set
+from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from typing import Any, Literal
@@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model_name: str | None = None,
-        allowed_special: Literal["all"] | Set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ):
         """Create a new TextSplitter."""
@@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
         else:
             enc = tiktoken.get_encoding(encoding_name)
         self._tokenizer = enc
-        self._allowed_special = allowed_special
-        self._disallowed_special = disallowed_special
+        self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
+        self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special

     def split_text(self, text: str) -> list[str]:
         def _encode(_text: str) -> list[int]:
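The last hunk annotates the instance attributes explicitly. Without the annotation, some checkers infer the attribute's type from the first assignment, which can be narrower than the declared parameter union; pinning it keeps later comparisons such as `== "all"` well-typed. A minimal sketch:

from collections.abc import Set as AbstractSet
from typing import Literal


class Splitter:
    def __init__(self, allowed_special: Literal["all"] | AbstractSet[str] = frozenset()) -> None:
        # The explicit annotation keeps the full union visible to every method, not just __init__.
        self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special

    def allows_everything(self) -> bool:
        return self._allowed_special == "all"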
@@ -1,10 +1,10 @@
 import json
 import logging
 import time
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast

 from core.app.app_config.entities import ModelConfig
-from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
 from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -36,6 +36,10 @@ default_retrieval_model = {
 }


+class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
+    metadata_filtering_conditions: dict[str, Any]
+
+
 class HitTestingService:
     @classmethod
     def retrieve(
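The new TypedDict extends the one imported from the retrieval service and marks the extra key as optional via `total=False`, so existing retrieval-model dicts that lack `metadata_filtering_conditions` still type-check. A standalone sketch; the base class below is a stand-in for the imported `DefaultRetrievalModelDict`, whose real fields live in retrieval_service:

from typing import Any, TypedDict


class DefaultRetrievalModelDict(TypedDict, total=False):
    # Stand-in for the imported base; field names here are illustrative.
    search_method: str
    top_k: int


class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
    # total=False makes the key optional, so it may be absent at runtime.
    metadata_filtering_conditions: dict[str, Any]


model: HitTestingRetrievalModelDict = {"search_method": "semantic_search", "top_k": 4}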
@@ -51,17 +55,18 @@ class HitTestingService:
         start = time.perf_counter()

         # get retrieval model , if the model is not setting , using default
-        if not retrieval_model:
-            retrieval_model = dataset.retrieval_model or default_retrieval_model
-        assert isinstance(retrieval_model, dict)
+        resolved_retrieval_model = cast(
+            HitTestingRetrievalModelDict,
+            retrieval_model or dataset.retrieval_model or default_retrieval_model,
+        )
         document_ids_filter = None
-        metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
-        if metadata_filtering_conditions and query:
+        metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {})
+        if metadata_filtering_conditions_raw and query:
             dataset_retrieval = DatasetRetrieval()

             from core.rag.entities import MetadataFilteringCondition

-            metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
+            metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw)

             metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
                 dataset_ids=[dataset.id],
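The refactor replaces in-place reassignment plus an `isinstance` assert with a single `cast` over the `or` chain, giving the resolved config a TypedDict shape so every later `.get(...)` is typed. A minimal sketch of the same move, with hypothetical field names:

from typing import Any, TypedDict, cast


class RetrievalModelDict(TypedDict, total=False):
    search_method: str
    metadata_filtering_conditions: dict[str, Any]


DEFAULTS: RetrievalModelDict = {"search_method": "semantic_search"}


def resolve(passed: dict[str, Any] | None, stored: dict[str, Any] | None) -> RetrievalModelDict:
    # cast() records the expected shape of the first non-empty dict in the chain;
    # the runtime value is unchanged, but .get() on the result is now typed.
    return cast(RetrievalModelDict, passed or stored or DEFAULTS)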
@@ -78,19 +83,21 @@ class HitTestingService:
         if metadata_condition and not document_ids_filter:
             return cls.compact_retrieve_response(query, [])
         all_documents = RetrievalService.retrieve(
-            retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
+            retrieval_method=RetrievalMethod(
+                resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
+            ),
             dataset_id=dataset.id,
             query=query,
             attachment_ids=attachment_ids,
-            top_k=retrieval_model.get("top_k", 4),
-            score_threshold=retrieval_model.get("score_threshold", 0.0)
-            if retrieval_model["score_threshold_enabled"]
+            top_k=resolved_retrieval_model.get("top_k", 4),
+            score_threshold=resolved_retrieval_model.get("score_threshold", 0.0)
+            if resolved_retrieval_model["score_threshold_enabled"]
             else 0.0,
-            reranking_model=retrieval_model.get("reranking_model", None)
-            if retrieval_model["reranking_enable"]
+            reranking_model=resolved_retrieval_model.get("reranking_model", None)
+            if resolved_retrieval_model["reranking_enable"]
             else None,
-            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
-            weights=retrieval_model.get("weights", None),
+            reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model",
+            weights=resolved_retrieval_model.get("weights", None),
             document_ids_filter=document_ids_filter,
         )
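One small pattern worth calling out from this hunk: the search method is built by passing a config string to the enum constructor, with the enum member itself as the `.get` default, since calling an Enum on one of its own members returns that member unchanged. The enum below is illustrative, not the real RetrievalMethod definition:

from enum import Enum


class RetrievalMethod(str, Enum):
    SEMANTIC_SEARCH = "semantic_search"
    FULL_TEXT_SEARCH = "full_text_search"


config: dict[str, str] = {}
# Works whether the key holds a raw string or is missing entirely.
method = RetrievalMethod(config.get("search_method", RetrievalMethod.SEMANTIC_SEARCH))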