refactor(api): type OpenSearch/Lindorm/Huawei VDB config params dicts with TypedDicts (#34870)

This commit is contained in:
dataCenter430
2026-04-09 17:34:34 -07:00
committed by GitHub
parent a31c1d2c69
commit c5c5c71d15
3 changed files with 59 additions and 24 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any
from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -19,6 +20,16 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
    """Keyword parameters built by ``HuaweiCloudVectorConfig.to_elasticsearch_params``.

    Declared with ``total=False`` because ``basic_auth`` is only set when both
    a username and a password are configured; the other keys are always
    populated by ``to_elasticsearch_params``.
    """

    hosts: list[str]  # derived from the comma-separated ``hosts`` config string
    verify_certs: bool
    ssl_show_warn: bool
    request_timeout: int
    retry_on_timeout: bool
    max_retries: int
    basic_auth: tuple[str, str]  # (username, password); optional key
def create_ssl_context() -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
@@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
raise ValueError("config HOSTS is required")
return values
def to_elasticsearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts.split(","),
"verify_certs": False,
"ssl_show_warn": False,
"request_timeout": 30000,
"retry_on_timeout": True,
"max_retries": 10,
}
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
    """Build the Elasticsearch connection parameters from this config.

    ``self.hosts`` is a comma-separated string. Each entry is stripped of
    surrounding whitespace and empty entries (for example from a trailing
    comma) are dropped, so values such as ``"a:9200, b:9200"`` yield clean
    host strings instead of ``" b:9200"`` or ``""``.

    Returns:
        The connection parameters. ``basic_auth`` is included only when both
        ``username`` and ``password`` are truthy.
    """
    # Normalise the host list: tolerate spaces after commas and stray commas.
    stripped_hosts = (entry.strip() for entry in self.hosts.split(","))
    host_list = [entry for entry in stripped_hosts if entry]

    params = HuaweiElasticsearchParamsDict(
        hosts=host_list,
        verify_certs=False,
        ssl_show_warn=False,
        # NOTE(review): value kept as-is; confirm the unit the client expects.
        request_timeout=30000,
        retry_on_timeout=True,
        max_retries=10,
    )
    # Credentials are optional; only attach them when both parts are present.
    if self.username and self.password:
        params["basic_auth"] = (self.username, self.password)
    return params

View File

@@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormOpenSearchParamsDict(TypedDict, total=False):
    """Keyword parameters built by ``LindormVectorStoreConfig.to_opensearch_params``.

    Declared with ``total=False`` because ``http_auth`` is only set when both
    a username and a password are configured.
    """

    hosts: str | None  # same type as ``LindormVectorStoreConfig.hosts``; passed through unchanged
    use_ssl: bool
    pool_maxsize: int
    timeout: int
    http_auth: tuple[str, str]  # (username, password); optional key
class LindormVectorStoreConfig(BaseModel):
hosts: str | None
username: str | None = None
@@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
raise ValueError("config PASSWORD is required")
return values
def to_opensearch_params(self) -> dict[str, Any]:
params: dict[str, Any] = {
"hosts": self.hosts,
"use_ssl": False,
"pool_maxsize": 128,
"timeout": 30,
}
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
    """Assemble the OpenSearch connection parameters for this Lindorm config.

    Returns:
        A mapping with ``hosts``, ``use_ssl``, ``pool_maxsize`` and
        ``timeout`` always present; ``http_auth`` is added only when both
        ``username`` and ``password`` are truthy.
    """
    connection_kwargs: LindormOpenSearchParamsDict = {
        "hosts": self.hosts,
        "use_ssl": False,
        "pool_maxsize": 128,
        "timeout": 30,
    }
    # Credentials are attached only when the pair is complete.
    has_credentials = bool(self.username and self.password)
    if has_credentials:
        connection_kwargs["http_auth"] = (self.username, self.password)
    return connection_kwargs

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
@@ -21,6 +22,20 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class _OpenSearchHostDict(TypedDict):
    """One entry of the ``hosts`` list in ``OpenSearchParamsDict``."""

    host: str
    port: int
class OpenSearchParamsDict(TypedDict, total=False):
    """Keyword parameters built by ``OpenSearchConfig.to_opensearch_params``.

    NOTE(review): ``total=False`` presumably exists because ``http_auth`` is
    added conditionally depending on the auth method — confirm against the
    full body of ``to_opensearch_params``.
    """

    hosts: list[_OpenSearchHostDict]  # list of {"host": ..., "port": ...} entries
    use_ssl: bool
    verify_certs: bool
    connection_class: type  # to_opensearch_params sets this to Urllib3HttpConnection
    pool_maxsize: int
    # Either a credentials pair or an AWS SigV4 signer object, per the union below.
    http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
class OpenSearchConfig(BaseModel):
host: str
port: int
@@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.verify_certs,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
def to_opensearch_params(self) -> OpenSearchParamsDict:
params = OpenSearchParamsDict(
hosts=[{"host": self.host, "port": self.port}],
use_ssl=self.secure,
verify_certs=self.verify_certs,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")