Mirror of https://github.com/langgenius/dify.git, synced 2026-01-07 10:00:46 -05:00
feat: mypy for all type check (#10921)
@@ -17,12 +17,19 @@ from models.dataset import Dataset
 class AnalyticdbVector(BaseVector):
     def __init__(
-        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
+        self,
+        collection_name: str,
+        api_config: AnalyticdbVectorOpenAPIConfig | None,
+        sql_config: AnalyticdbVectorBySqlConfig | None,
     ):
         super().__init__(collection_name)
         if api_config is not None:
-            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
+            self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
+                collection_name, api_config
+            )
         else:
+            if sql_config is None:
+                raise ValueError("Either api_config or sql_config must be provided")
             self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
 
     def get_type(self) -> str:
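Making both configs Optional forces every caller through an explicit None check, and the union annotation on the first assignment tells mypy that self.analyticdb_vector may hold either backend; the new ValueError turns a previously implicit assumption (that at least one config is supplied) into a runtime guarantee. A minimal sketch of the same pattern, with hypothetical stand-in classes:

    class ApiBackend:
        def __init__(self, cfg: str) -> None: ...

    class SqlBackend:
        def __init__(self, cfg: str) -> None: ...

    class Client:
        def __init__(self, api_cfg: str | None, sql_cfg: str | None) -> None:
            if api_cfg is not None:
                # Annotating the first assignment pins the attribute to the
                # union type, so the else-branch assignment type-checks too.
                self.backend: ApiBackend | SqlBackend = ApiBackend(api_cfg)
            else:
                if sql_cfg is None:
                    raise ValueError("either api_cfg or sql_cfg is required")
                self.backend = SqlBackend(sql_cfg)  # sql_cfg narrowed to str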
@@ -33,8 +40,8 @@ class AnalyticdbVector(BaseVector):
         self.analyticdb_vector._create_collection_if_not_exists(dimension)
         self.analyticdb_vector.add_texts(texts, embeddings)
 
-    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        self.analyticdb_vector.add_texts(texts, embeddings)
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        self.analyticdb_vector.add_texts(documents, embeddings)
 
     def text_exists(self, id: str) -> bool:
         return self.analyticdb_vector.text_exists(id)
@@ -68,13 +75,13 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
         if dify_config.ANALYTICDB_HOST is None:
             # implemented through OpenAPI
             apiConfig = AnalyticdbVectorOpenAPIConfig(
-                access_key_id=dify_config.ANALYTICDB_KEY_ID,
-                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
-                region_id=dify_config.ANALYTICDB_REGION_ID,
-                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
-                account=dify_config.ANALYTICDB_ACCOUNT,
-                account_password=dify_config.ANALYTICDB_PASSWORD,
-                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
+                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
+                region_id=dify_config.ANALYTICDB_REGION_ID or "",
+                instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
+                account=dify_config.ANALYTICDB_ACCOUNT or "",
+                account_password=dify_config.ANALYTICDB_PASSWORD or "",
+                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
                 namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
             )
             sqlConfig = None
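The `or ""` fallback is the standard way to hand an Optional string setting to a field declared as plain str: mypy types the whole expression as str, and an unset value degrades to an empty string rather than a type error. A sketch, assuming (as the fix implies) that these dify_config fields are Optional[str]:

    from pydantic import BaseModel

    class OpenAPIConfig(BaseModel):
        access_key_id: str  # declared non-optional

    key_id: str | None = None  # hypothetical optional setting

    # `key_id or ""` is typed str, so the call satisfies mypy; the fallback
    # only takes effect when the setting is unset (or empty).
    cfg = OpenAPIConfig(access_key_id=key_id or "")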
@@ -83,11 +90,11 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
             sqlConfig = AnalyticdbVectorBySqlConfig(
                 host=dify_config.ANALYTICDB_HOST,
                 port=dify_config.ANALYTICDB_PORT,
-                account=dify_config.ANALYTICDB_ACCOUNT,
-                account_password=dify_config.ANALYTICDB_PASSWORD,
+                account=dify_config.ANALYTICDB_ACCOUNT or "",
+                account_password=dify_config.ANALYTICDB_PASSWORD or "",
                 min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                 max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
-                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
             )
             apiConfig = None
         return AnalyticdbVector(
@@ -1,5 +1,5 @@
 import json
-from typing import Any
+from typing import Any, Optional
 
 from pydantic import BaseModel, model_validator
@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
     account: str
     account_password: str
     namespace: str = "dify"
-    namespace_password: str = (None,)
+    namespace_password: Optional[str] = None
     metrics: str = "cosine"
     read_timeout: int = 60000
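The old default was a genuine bug, not just a typing error: the trailing comma in `(None,)` makes the default a one-element tuple rather than None. Declaring the field `Optional[str] = None` matches both the intent and the annotation. A self-contained sketch:

    from typing import Optional
    from pydantic import BaseModel

    class Config(BaseModel):
        # Old: `namespace_password: str = (None,)` -- the trailing comma made
        # the default a one-element tuple, and since pydantic does not
        # validate defaults unless asked, the tuple silently reached any code
        # expecting a string.
        namespace_password: Optional[str] = None  # the actual intent

    print(Config().namespace_password)  # -> None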
@@ -55,8 +55,8 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
 class AnalyticdbVectorOpenAPI:
     def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
         try:
-            from alibabacloud_gpdb20160503.client import Client
-            from alibabacloud_tea_openapi import models as open_api_models
+            from alibabacloud_gpdb20160503.client import Client  # type: ignore
+            from alibabacloud_tea_openapi import models as open_api_models  # type: ignore
         except:
             raise ImportError(_import_err_msg)
         self._collection_name = collection_name.lower()
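The alibabacloud SDK modules ship without type stubs or a py.typed marker, so mypy cannot analyze them and reports an import-untyped error (message paraphrased) on each import; the inline `# type: ignore` silences exactly that line while leaving typed imports checked. A per-module ignore_missing_imports override in mypy configuration would be the centralized alternative. A sketch of the per-import pattern:

    import json  # stdlib and fully typed: mypy checks this as usual

    try:
        # Untyped third-party SDK import: suppressed on this line only,
        # rather than disabling import checking globally.
        from Tea.exceptions import TeaException  # type: ignore
    except ImportError as e:
        raise ImportError("the alibabacloud SDK is required") from e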
@@ -77,7 +77,7 @@ class AnalyticdbVectorOpenAPI:
             redis_client.set(database_exist_cache_key, 1, ex=3600)
 
     def _initialize_vector_database(self) -> None:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models  # type: ignore
 
         request = gpdb_20160503_models.InitVectorDatabaseRequest(
             dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ class AnalyticdbVectorOpenAPI:
 
     def _create_namespace_if_not_exists(self) -> None:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        from Tea.exceptions import TeaException
+        from Tea.exceptions import TeaException  # type: ignore
 
         try:
             request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ class AnalyticdbVectorOpenAPI:
 
         rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
         for doc, embedding in zip(documents, embeddings, strict=True):
-            metadata = {
-                "ref_doc_id": doc.metadata["doc_id"],
-                "page_content": doc.page_content,
-                "metadata_": json.dumps(doc.metadata),
-            }
-            rows.append(
-                gpdb_20160503_models.UpsertCollectionDataRequestRows(
-                    vector=embedding,
-                    metadata=metadata,
+            if doc.metadata is not None:
+                metadata = {
+                    "ref_doc_id": doc.metadata["doc_id"],
+                    "page_content": doc.page_content,
+                    "metadata_": json.dumps(doc.metadata),
+                }
+                rows.append(
+                    gpdb_20160503_models.UpsertCollectionDataRequestRows(
+                        vector=embedding,
+                        metadata=metadata,
+                    )
                 )
-            )
         request = gpdb_20160503_models.UpsertCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
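Document.metadata is Optional, so subscripting it directly is a union-attr error under mypy; the `if doc.metadata is not None:` guard narrows it to a plain dict for the entire indented block. Note the behavior change this implies: documents without metadata are now skipped rather than raising. A runnable sketch with a stand-in Doc class:

    import json
    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class Doc:  # stand-in for core.rag.models.document.Document
        page_content: str
        metadata: Optional[dict[str, Any]] = None

    def build_rows(docs: list[Doc]) -> list[dict[str, str]]:
        rows: list[dict[str, str]] = []
        for doc in docs:
            if doc.metadata is not None:
                # mypy narrows doc.metadata to dict[str, Any] in this block,
                # so the subscript and json.dumps both type-check.
                rows.append({
                    "ref_doc_id": str(doc.metadata["doc_id"]),
                    "page_content": doc.page_content,
                    "metadata_": json.dumps(doc.metadata),
                })
        return rows

    print(build_rows([Doc("hello", {"doc_id": "1"}), Doc("no-meta")]))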
@@ -258,7 +259,7 @@ class AnalyticdbVectorOpenAPI:
                 metadata=metadata,
             )
             documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return documents
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
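A statement-level guard cannot be used inside a lambda, so the sort key narrows with a conditional expression instead: mypy treats x.metadata as a dict in the true branch, and documents lacking metadata sort last with a score of 0. A sketch with a minimal stand-in class:

    from typing import Any, Optional

    class Doc:
        def __init__(self, metadata: Optional[dict[str, Any]] = None) -> None:
            self.metadata = metadata

    docs = [Doc({"score": 0.9}), Doc(), Doc({"score": 0.2})]
    # A plain x.metadata["score"] would be a union-attr error; the
    # conditional expression narrows x.metadata within its true branch.
    docs.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)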
@@ -290,7 +291,7 @@ class AnalyticdbVectorOpenAPI:
                 metadata=metadata,
             )
             documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return documents
 
     def delete(self) -> None:
@@ -3,8 +3,8 @@ import uuid
 from contextlib import contextmanager
 from typing import Any
 
-import psycopg2.extras
-import psycopg2.pool
+import psycopg2.extras  # type: ignore
+import psycopg2.pool  # type: ignore
 from pydantic import BaseModel, model_validator
 
 from core.rag.models.document import Document
@@ -75,6 +75,7 @@ class AnalyticdbVectorBySql:
 
     @contextmanager
     def _get_cursor(self):
+        assert self.pool is not None, "Connection pool is not initialized"
         conn = self.pool.getconn()
         cur = conn.cursor()
         try:
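When an attribute is typed Optional because it is initialized lazily, an assert is the idiomatic narrowing tool: mypy treats self.pool as non-None for the rest of the method, and a premature call fails with a clear message instead of an AttributeError. (Asserts are stripped under python -O, so this is a typing aid plus a debug check, not hard validation.) A sketch assuming a hypothetical lazily created pool:

    from typing import Optional

    class Store:
        def __init__(self) -> None:
            self.pool: Optional[list[int]] = None  # lazy resource stand-in

        def use(self) -> int:
            # Narrows Optional[list[int]] to list[int] for everything below.
            assert self.pool is not None, "pool is not initialized"
            return len(self.pool)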
@@ -156,16 +157,17 @@ class AnalyticdbVectorBySql:
                 VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
             """
         for i, doc in enumerate(documents):
-            values.append(
-                (
-                    id_prefix + str(i),
-                    doc.metadata.get("doc_id", str(uuid.uuid4())),
-                    embeddings[i],
-                    doc.page_content,
-                    json.dumps(doc.metadata),
-                    doc.page_content,
+            if doc.metadata is not None:
+                values.append(
+                    (
+                        id_prefix + str(i),
+                        doc.metadata.get("doc_id", str(uuid.uuid4())),
+                        embeddings[i],
+                        doc.page_content,
+                        json.dumps(doc.metadata),
+                        doc.page_content,
+                    )
                 )
-            )
         with self._get_cursor() as cur:
             psycopg2.extras.execute_batch(cur, sql, values)