diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index df02c584ed..90d6d98c63 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -5,6 +5,7 @@ from typing import Any from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -19,6 +20,16 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class HuaweiElasticsearchParamsDict(TypedDict, total=False): + hosts: list[str] + verify_certs: bool + ssl_show_warn: bool + request_timeout: int + retry_on_timeout: bool + max_retries: int + basic_auth: tuple[str, str] + + def create_ssl_context() -> ssl.SSLContext: ssl_context = ssl.create_default_context() ssl_context.check_hostname = False @@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel): raise ValueError("config HOSTS is required") return values - def to_elasticsearch_params(self) -> dict[str, Any]: - params = { - "hosts": self.hosts.split(","), - "verify_certs": False, - "ssl_show_warn": False, - "request_timeout": 30000, - "retry_on_timeout": True, - "max_retries": 10, - } + def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict: + params = HuaweiElasticsearchParamsDict( + hosts=self.hosts.split(","), + verify_certs=False, + ssl_show_warn=False, + request_timeout=30000, + retry_on_timeout=True, + max_retries=10, + ) if self.username and self.password: params["basic_auth"] = (self.username, self.password) return params diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index bfcb620618..fbe0bcad02 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from tenacity import retry, stop_after_attempt, wait_exponential +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field" UGC_INDEX_PREFIX = "ugc_index" +class LindormOpenSearchParamsDict(TypedDict, total=False): + hosts: str | None + use_ssl: bool + pool_maxsize: int + timeout: int + http_auth: tuple[str, str] + + class LindormVectorStoreConfig(BaseModel): hosts: str | None username: str | None = None @@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel): raise ValueError("config PASSWORD is required") return values - def to_opensearch_params(self) -> dict[str, Any]: - params: dict[str, Any] = { - "hosts": self.hosts, - "use_ssl": False, - "pool_maxsize": 128, - "timeout": 30, - } + def to_opensearch_params(self) -> LindormOpenSearchParamsDict: + params = LindormOpenSearchParamsDict( + hosts=self.hosts, + use_ssl=False, + pool_maxsize=128, + timeout=30, + ) if self.username and self.password: params["http_auth"] = (self.username, self.password) return params diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 2f77776807..50d18cdc4c 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -6,6 +6,7 @@ from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from configs.middleware.vdb.opensearch_config import AuthMethod @@ -21,6 +22,20 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class _OpenSearchHostDict(TypedDict): + host: str + port: int + + +class OpenSearchParamsDict(TypedDict, total=False): + hosts: list[_OpenSearchHostDict] + use_ssl: bool + verify_certs: bool + connection_class: type + pool_maxsize: int + http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth + + class OpenSearchConfig(BaseModel): host: str port: int @@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel): service=self.aws_service, # type: ignore[arg-type] ) - def to_opensearch_params(self) -> dict[str, Any]: - params = { - "hosts": [{"host": self.host, "port": self.port}], - "use_ssl": self.secure, - "verify_certs": self.verify_certs, - "connection_class": Urllib3HttpConnection, - "pool_maxsize": 20, - } + def to_opensearch_params(self) -> OpenSearchParamsDict: + params = OpenSearchParamsDict( + hosts=[{"host": self.host, "port": self.port}], + use_ssl=self.secure, + verify_certs=self.verify_certs, + connection_class=Urllib3HttpConnection, + pool_maxsize=20, + ) if self.auth_method == "basic": logger.info("Using basic authentication for OpenSearch Vector DB")