feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -3,7 +3,7 @@ import uuid
from typing import Any, Optional
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
@@ -58,14 +58,14 @@ class RelytVector(BaseVector):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._fields: list[str] = []
self._group_id = group_id
def get_type(self) -> str:
return VectorType.RELYT
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
index_params: dict[str, Any] = {}
metadatas = [d.metadata for d in texts]
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
@@ -107,10 +107,10 @@ class RelytVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from pgvecto_rs.sqlalchemy import VECTOR
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
for metadata in metadatas:
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
@@ -242,10 +242,6 @@ class RelytVector(BaseVector):
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
@@ -275,7 +271,7 @@ class RelytVector(BaseVector):
# Execute the query and fetch the results
with self.client.connect() as conn:
results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
results = conn.execute(sql_text(sql_query), params).fetchall()
documents_with_scores = [
(
@@ -307,11 +303,11 @@ class RelytVectorFactory(AbstractVectorFactory):
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=dify_config.RELYT_HOST,
host=dify_config.RELYT_HOST or "localhost",
port=dify_config.RELYT_PORT,
user=dify_config.RELYT_USER,
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
user=dify_config.RELYT_USER or "",
password=dify_config.RELYT_PASSWORD or "",
database=dify_config.RELYT_DATABASE or "default",
),
group_id=dataset.id,
)