From ed8d3f3e8d48e487bc17733d3e0a07bbb53c19e3 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Fri, 24 Apr 2026 00:30:28 +0200 Subject: [PATCH] refactor(api): fix pyright errors in jieba, milvus, couchbase, oracle, and router (#34938) Co-authored-by: Asuka Minato --- .../rag/datasource/keyword/jieba/jieba.py | 3 +- .../jieba/jieba_keyword_table_handler.py | 2 +- .../multi_dataset_function_call_router.py | 2 +- .../dify_vdb_couchbase/couchbase_vector.py | 4 +- .../src/dify_vdb_milvus/milvus_vector.py | 7 ++-- .../src/dify_vdb_oracle/oraclevector.py | 40 +++++++++++-------- 6 files changed, 34 insertions(+), 24 deletions(-) diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 242da520c1..392af351b6 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -156,7 +156,8 @@ class Jieba(BaseKeyword): if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return dict(keyword_table_dict["__data__"]["table"]) + data: Any = keyword_table_dict["__data__"] + return dict(data["table"]) else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 1ca6303af6..2af8238cc4 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -109,7 +109,7 @@ class JiebaKeywordTableHandler: """Extract keywords with JIEBA tfidf.""" keywords = self._tfidf.extract_tags( sentence=text, - topK=max_keywords_per_chunk, + topK=max_keywords_per_chunk or 10, ) # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. keywords = cast(list[str], keywords) diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 426d1b67dc..dd17545c86 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter: result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType] prompt_messages=prompt_messages, tools=dataset_tools, - stream=False, + stream=False, # pyright: ignore[reportArgumentType] model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) usage = result.usage or LLMUsage.empty_usage() diff --git a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py index 815ac30c0b..bab176e285 100644 --- a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py +++ b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py @@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector): auth = PasswordAuthenticator(config.user, config.password) options = ClusterOptions(auth) - self._cluster = Cluster(config.connection_string, options) + self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType] self._bucket = self._cluster.bucket(config.bucket_name) self._scope = self._bucket.scope(config.scope_name) self._bucket_name = config.bucket_name @@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) try: - CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) + CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue] search_iter = self._scope.search( self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"]) ) diff --git a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index 46f3224a95..823b877707 100644 --- a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from packaging import version from pydantic import BaseModel, model_validator @@ -92,7 +92,7 @@ class MilvusVector(BaseVector): def _load_collection_fields(self, fields: list[str] | None = None): if fields is None: # Load collection fields from remote server - collection_info = self._client.describe_collection(self._collection_name) + collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name)) fields = [field["name"] for field in collection_info["fields"]] # Since primary field is auto-id, no need to track it self._fields = [f for f in fields if f != Field.PRIMARY_KEY] @@ -106,7 +106,8 @@ class MilvusVector(BaseVector): return False try: - milvus_version = self._client.get_server_version() + milvus_version_raw = self._client.get_server_version() + milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw) # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility if "Zilliz Cloud" in milvus_version: return True diff --git a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py index 70377c82c8..5d9ab38529 100644 --- a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py +++ b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py @@ -3,7 +3,7 @@ import json import logging import re import uuid -from typing import Any +from typing import Any, TypedDict import jieba.posseg as pseg # type: ignore import numpy @@ -25,6 +25,18 @@ logger = logging.getLogger(__name__) oracledb.defaults.fetch_lobs = False +class _OraclePoolParams(TypedDict, total=False): + user: str + password: str + dsn: str + min: int + max: int + increment: int + config_dir: str | None + wallet_location: str | None + wallet_password: str | None + + class OracleVectorConfig(BaseModel): user: str password: str @@ -127,22 +139,18 @@ class OracleVector(BaseVector): return connection def _create_connection_pool(self, config: OracleVectorConfig): - pool_params = { - "user": config.user, - "password": config.password, - "dsn": config.dsn, - "min": 1, - "max": 5, - "increment": 1, - } + pool_params = _OraclePoolParams( + user=config.user, + password=config.password, + dsn=config.dsn, + min=1, + max=5, + increment=1, + ) if config.is_autonomous: - pool_params.update( - { - "config_dir": config.config_dir, - "wallet_location": config.wallet_location, - "wallet_password": config.wallet_password, - } - ) + pool_params["config_dir"] = config.config_dir + pool_params["wallet_location"] = config.wallet_location + pool_params["wallet_password"] = config.wallet_password return oracledb.create_pool(**pool_params) def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):