feat: mypy for all type check (#10921)

Author:    yihong
Date:      2024-12-24 18:38:51 +08:00
Committer: GitHub
Parent:    c91e8b1737
Commit:    56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -17,12 +17,19 @@ from models.dataset import Dataset

 class AnalyticdbVector(BaseVector):
     def __init__(
-        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
+        self,
+        collection_name: str,
+        api_config: AnalyticdbVectorOpenAPIConfig | None,
+        sql_config: AnalyticdbVectorBySqlConfig | None,
     ):
         super().__init__(collection_name)
         if api_config is not None:
-            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
+            self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
+                collection_name, api_config
+            )
         else:
+            if sql_config is None:
+                raise ValueError("Either api_config or sql_config must be provided")
             self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)

     def get_type(self) -> str:
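The key fix in this hunk is the explicit union annotation: mypy infers an attribute's type from its first assignment, so without it the `else` branch would be an incompatible re-assignment. The added `raise ValueError` also does type-checking work, proving to mypy that `sql_config` cannot be None when `AnalyticdbVectorBySql` is constructed. A minimal, self-contained sketch of the same pattern (class names are illustrative, not from the Dify codebase):

    class HttpBackend:
        def __init__(self, url: str) -> None:
            self.url = url

    class FileBackend:
        def __init__(self, path: str) -> None:
            self.path = path

    class Client:
        def __init__(self, url: str | None, path: str | None) -> None:
            if url is not None:
                # Annotating at the first assignment tells mypy the attribute
                # may hold either backend, so the else-branch assignment checks.
                self.backend: HttpBackend | FileBackend = HttpBackend(url)
            else:
                if path is None:
                    raise ValueError("Either url or path must be provided")
                self.backend = FileBackend(path)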
@@ -33,8 +40,8 @@ class AnalyticdbVector(BaseVector):
         self.analyticdb_vector._create_collection_if_not_exists(dimension)
         self.analyticdb_vector.add_texts(texts, embeddings)

-    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        self.analyticdb_vector.add_texts(texts, embeddings)
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        self.analyticdb_vector.add_texts(documents, embeddings)

     def text_exists(self, id: str) -> bool:
         return self.analyticdb_vector.text_exists(id)
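Renaming `texts` to `documents` aligns the override with the base-class signature: mypy reports an override whose parameter names differ as an incompatible signature, because keyword calls made through the base type would break. A short sketch of the rule (Base/Impl are illustrative):

    from abc import ABC, abstractmethod

    class Base(ABC):
        @abstractmethod
        def add_texts(self, documents: list[str]) -> None: ...

    class Impl(Base):
        # The parameter name matches the base class; naming it `texts`
        # would break callers using add_texts(documents=...), and mypy
        # flags the override as incompatible with the supertype.
        def add_texts(self, documents: list[str]) -> None:
            print(len(documents))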
@@ -68,13 +75,13 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
         if dify_config.ANALYTICDB_HOST is None:
             # implemented through OpenAPI
             apiConfig = AnalyticdbVectorOpenAPIConfig(
-                access_key_id=dify_config.ANALYTICDB_KEY_ID,
-                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
-                region_id=dify_config.ANALYTICDB_REGION_ID,
-                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
-                account=dify_config.ANALYTICDB_ACCOUNT,
-                account_password=dify_config.ANALYTICDB_PASSWORD,
-                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
+                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
+                region_id=dify_config.ANALYTICDB_REGION_ID or "",
+                instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
+                account=dify_config.ANALYTICDB_ACCOUNT or "",
+                account_password=dify_config.ANALYTICDB_PASSWORD or "",
+                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
                 namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
             )
             sqlConfig = None
@@ -83,11 +90,11 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
             sqlConfig = AnalyticdbVectorBySqlConfig(
                 host=dify_config.ANALYTICDB_HOST,
                 port=dify_config.ANALYTICDB_PORT,
-                account=dify_config.ANALYTICDB_ACCOUNT,
-                account_password=dify_config.ANALYTICDB_PASSWORD,
+                account=dify_config.ANALYTICDB_ACCOUNT or "",
+                account_password=dify_config.ANALYTICDB_PASSWORD or "",
                 min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                 max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
-                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
             )
             apiConfig = None
         return AnalyticdbVector(
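All of these `or ""` fallbacks exist because the config models declare non-optional `str` fields while the corresponding `dify_config` settings are typed `str | None`. A self-contained sketch of the technique and its trade-off (function and names are illustrative):

    def build_dsn(host: str, user: str | None, password: str | None) -> str:
        # `user or ""` collapses None to "", so the value satisfies a str
        # parameter. Caveat: it also replaces a deliberately empty string,
        # so it only suits settings where "" already means "unset".
        return f"postgresql://{user or ''}:{password or ''}@{host}"

    print(build_dsn("localhost", None, None))  # postgresql://:@localhost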

View File

@@ -1,5 +1,5 @@
 import json
-from typing import Any
+from typing import Any, Optional

 from pydantic import BaseModel, model_validator
@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
     account: str
     account_password: str
     namespace: str = "dify"
-    namespace_password: str = (None,)
+    namespace_password: Optional[str] = None
     metrics: str = "cosine"
     read_timeout: int = 60000
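This one is a genuine bug surfaced by the type check rather than a pure annotation change: `(None,)` is a one-element tuple, almost certainly a stray trailing comma. Pydantic does not validate default values unless asked to, so the tuple slipped through at runtime; the field is really an optional string. A minimal sketch:

    from typing import Optional

    from pydantic import BaseModel

    class Config(BaseModel):
        # Buggy: a trailing comma turns the default into a tuple, and
        # pydantic leaves defaults unvalidated unless validate_default is set.
        #   namespace_password: str = (None,)
        # Fixed: declare the field optional with a plain None default.
        namespace_password: Optional[str] = None

    print(Config().namespace_password)  # None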
@@ -55,8 +55,8 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
 class AnalyticdbVectorOpenAPI:
     def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
         try:
-            from alibabacloud_gpdb20160503.client import Client
-            from alibabacloud_tea_openapi import models as open_api_models
+            from alibabacloud_gpdb20160503.client import Client  # type: ignore
+            from alibabacloud_tea_openapi import models as open_api_models  # type: ignore
         except:
             raise ImportError(_import_err_msg)
         self._collection_name = collection_name.lower()
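The `# type: ignore` comments here (and on the psycopg2 imports later in the commit) silence mypy for imports of packages that ship no type stubs; placed mid-file, the comment applies only to its own line. A self-contained illustration of the mechanics, using a deliberate error instead of an untyped third-party package so it runs anywhere:

    # Without the comment, mypy reports: Incompatible types in assignment
    # (expression has type "str", variable has type "int")  [assignment]
    x: int = "not an int"  # type: ignore[assignment]
    print(x)  # runs fine; only the static check is suppressed

Scoping the ignore to an error code, as in `[assignment]` above, is narrower than a bare `# type: ignore` and generally preferable where the code is known.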
@@ -77,7 +77,7 @@ class AnalyticdbVectorOpenAPI:
             redis_client.set(database_exist_cache_key, 1, ex=3600)

     def _initialize_vector_database(self) -> None:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models  # type: ignore

         request = gpdb_20160503_models.InitVectorDatabaseRequest(
             dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ class AnalyticdbVectorOpenAPI:
     def _create_namespace_if_not_exists(self) -> None:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        from Tea.exceptions import TeaException
+        from Tea.exceptions import TeaException  # type: ignore

         try:
             request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ class AnalyticdbVectorOpenAPI:
         rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
         for doc, embedding in zip(documents, embeddings, strict=True):
-            metadata = {
-                "ref_doc_id": doc.metadata["doc_id"],
-                "page_content": doc.page_content,
-                "metadata_": json.dumps(doc.metadata),
-            }
-            rows.append(
-                gpdb_20160503_models.UpsertCollectionDataRequestRows(
-                    vector=embedding,
-                    metadata=metadata,
-                )
-            )
+            if doc.metadata is not None:
+                metadata = {
+                    "ref_doc_id": doc.metadata["doc_id"],
+                    "page_content": doc.page_content,
+                    "metadata_": json.dumps(doc.metadata),
+                }
+                rows.append(
+                    gpdb_20160503_models.UpsertCollectionDataRequestRows(
+                        vector=embedding,
+                        metadata=metadata,
+                    )
+                )
         request = gpdb_20160503_models.UpsertCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
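`Document.metadata` is typed as optional, so every subscript like `doc.metadata["doc_id"]` is a mypy error until the value is narrowed; wrapping the loop body in an explicit None check is the idiomatic fix. A runnable sketch (the Doc dataclass is illustrative, not the real Document model):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Doc:
        page_content: str
        metadata: Optional[dict] = None

    def collect_ids(docs: list[Doc]) -> list[str]:
        ids: list[str] = []
        for doc in docs:
            # mypy narrows doc.metadata from Optional[dict] to dict inside
            # this branch, so indexing it no longer errors.
            if doc.metadata is not None:
                ids.append(doc.metadata["doc_id"])
        return ids

    print(collect_ids([Doc("a", {"doc_id": "1"}), Doc("b")]))  # ['1']

One behavioral nuance: documents without metadata are now silently skipped rather than raising a TypeError, which is presumably acceptable here because metadata is set upstream.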
@@ -258,7 +259,7 @@ class AnalyticdbVectorOpenAPI:
                     metadata=metadata,
                 )
                 documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return documents

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -290,7 +291,7 @@ class AnalyticdbVectorOpenAPI:
                     metadata=metadata,
                 )
                 documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return documents

     def delete(self) -> None:
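The same Optional-metadata issue appears inside the sort keys, where no `if` statement can narrow the type; an inline conditional supplies the fallback instead. A short, self-contained sketch (Hit is illustrative):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Hit:
        metadata: Optional[dict] = None

    hits = [Hit({"score": 0.5}), Hit(), Hit({"score": 0.9})]
    # mypy narrows h.metadata to dict in the true arm of the conditional
    # expression, and missing metadata sorts as score 0.
    ranked = sorted(hits, key=lambda h: h.metadata["score"] if h.metadata else 0, reverse=True)
    print([h.metadata for h in ranked])  # [{'score': 0.9}, {'score': 0.5}, None]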

View File

@@ -3,8 +3,8 @@ import uuid
 from contextlib import contextmanager
 from typing import Any

-import psycopg2.extras
-import psycopg2.pool
+import psycopg2.extras  # type: ignore
+import psycopg2.pool  # type: ignore
 from pydantic import BaseModel, model_validator

 from core.rag.models.document import Document
@@ -75,6 +75,7 @@ class AnalyticdbVectorBySql:
     @contextmanager
     def _get_cursor(self):
+        assert self.pool is not None, "Connection pool is not initialized"
         conn = self.pool.getconn()
         cur = conn.cursor()
         try:
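The added assert is the standard way to narrow an Optional attribute that is known to be initialized by the time a method runs: mypy treats the code after `assert x is not None` as having the non-optional type, and at runtime the assert fails fast with a clear message instead of an AttributeError. A minimal sketch:

    from typing import Optional

    class PoolHolder:
        def __init__(self) -> None:
            self.pool: Optional[list[int]] = None

        def take(self) -> int:
            # Narrows self.pool from Optional[list[int]] to list[int]
            # for the remainder of the method.
            assert self.pool is not None, "Connection pool is not initialized"
            return self.pool.pop()

    holder = PoolHolder()
    holder.pool = [1, 2, 3]
    print(holder.take())  # 3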
@@ -156,16 +157,17 @@ class AnalyticdbVectorBySql:
                 VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
             """
         for i, doc in enumerate(documents):
-            values.append(
-                (
-                    id_prefix + str(i),
-                    doc.metadata.get("doc_id", str(uuid.uuid4())),
-                    embeddings[i],
-                    doc.page_content,
-                    json.dumps(doc.metadata),
-                    doc.page_content,
-                )
-            )
+            if doc.metadata is not None:
+                values.append(
+                    (
+                        id_prefix + str(i),
+                        doc.metadata.get("doc_id", str(uuid.uuid4())),
+                        embeddings[i],
+                        doc.page_content,
+                        json.dumps(doc.metadata),
+                        doc.page_content,
+                    )
+                )
         with self._get_cursor() as cur:
             psycopg2.extras.execute_batch(cur, sql, values)