mirror of
https://github.com/langgenius/dify.git
synced 2026-04-11 06:00:28 -04:00
refactor(api): type OpenSearch/Lindorm/Huawei VDB config params dicts with TypedDicts (#34870)
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@@ -19,6 +20,16 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
|
||||
hosts: list[str]
|
||||
verify_certs: bool
|
||||
ssl_show_warn: bool
|
||||
request_timeout: int
|
||||
retry_on_timeout: bool
|
||||
max_retries: int
|
||||
basic_auth: tuple[str, str]
|
||||
|
||||
|
||||
def create_ssl_context() -> ssl.SSLContext:
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
@@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
|
||||
raise ValueError("config HOSTS is required")
|
||||
return values
|
||||
|
||||
def to_elasticsearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": self.hosts.split(","),
|
||||
"verify_certs": False,
|
||||
"ssl_show_warn": False,
|
||||
"request_timeout": 30000,
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": 10,
|
||||
}
|
||||
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
|
||||
params = HuaweiElasticsearchParamsDict(
|
||||
hosts=self.hosts.split(","),
|
||||
verify_certs=False,
|
||||
ssl_show_warn=False,
|
||||
request_timeout=30000,
|
||||
retry_on_timeout=True,
|
||||
max_retries=10,
|
||||
)
|
||||
if self.username and self.password:
|
||||
params["basic_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
@@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
|
||||
UGC_INDEX_PREFIX = "ugc_index"
|
||||
|
||||
|
||||
class LindormOpenSearchParamsDict(TypedDict, total=False):
|
||||
hosts: str | None
|
||||
use_ssl: bool
|
||||
pool_maxsize: int
|
||||
timeout: int
|
||||
http_auth: tuple[str, str]
|
||||
|
||||
|
||||
class LindormVectorStoreConfig(BaseModel):
|
||||
hosts: str | None
|
||||
username: str | None = None
|
||||
@@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
|
||||
raise ValueError("config PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params: dict[str, Any] = {
|
||||
"hosts": self.hosts,
|
||||
"use_ssl": False,
|
||||
"pool_maxsize": 128,
|
||||
"timeout": 30,
|
||||
}
|
||||
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
|
||||
params = LindormOpenSearchParamsDict(
|
||||
hosts=self.hosts,
|
||||
use_ssl=False,
|
||||
pool_maxsize=128,
|
||||
timeout=30,
|
||||
)
|
||||
if self.username and self.password:
|
||||
params["http_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.opensearch_config import AuthMethod
|
||||
@@ -21,6 +22,20 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _OpenSearchHostDict(TypedDict):
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
class OpenSearchParamsDict(TypedDict, total=False):
|
||||
hosts: list[_OpenSearchHostDict]
|
||||
use_ssl: bool
|
||||
verify_certs: bool
|
||||
connection_class: type
|
||||
pool_maxsize: int
|
||||
http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
@@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
|
||||
service=self.aws_service, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": [{"host": self.host, "port": self.port}],
|
||||
"use_ssl": self.secure,
|
||||
"verify_certs": self.verify_certs,
|
||||
"connection_class": Urllib3HttpConnection,
|
||||
"pool_maxsize": 20,
|
||||
}
|
||||
def to_opensearch_params(self) -> OpenSearchParamsDict:
|
||||
params = OpenSearchParamsDict(
|
||||
hosts=[{"host": self.host, "port": self.port}],
|
||||
use_ssl=self.secure,
|
||||
verify_certs=self.verify_certs,
|
||||
connection_class=Urllib3HttpConnection,
|
||||
pool_maxsize=20,
|
||||
)
|
||||
|
||||
if self.auth_method == "basic":
|
||||
logger.info("Using basic authentication for OpenSearch Vector DB")
|
||||
|
||||
Reference in New Issue
Block a user