mirror of
https://github.com/langgenius/dify.git
synced 2026-02-26 17:02:03 -05:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -3,7 +3,7 @@ import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
|
||||
from sqlalchemy import Column, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -58,14 +58,14 @@ class RelytVector(BaseVector):
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self.client = create_engine(self._url)
|
||||
self._fields = []
|
||||
self._fields: list[str] = []
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.RELYT
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {}
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
|
||||
index_params: dict[str, Any] = {}
|
||||
metadatas = [d.metadata for d in texts]
|
||||
self.create_collection(len(embeddings[0]))
|
||||
self.embedding_dimension = len(embeddings[0])
|
||||
@@ -107,10 +107,10 @@ class RelytVector(BaseVector):
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
from pgvecto_rs.sqlalchemy import VECTOR
|
||||
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
|
||||
|
||||
ids = [str(uuid.uuid1()) for _ in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
metadatas = [d.metadata for d in documents if d.metadata is not None]
|
||||
for metadata in metadatas:
|
||||
metadata["group_id"] = self._group_id
|
||||
texts = [d.page_content for d in documents]
|
||||
@@ -242,10 +242,6 @@ class RelytVector(BaseVector):
|
||||
filter: Optional[dict] = None,
|
||||
) -> list[tuple[Document, float]]:
|
||||
# Add the filter if provided
|
||||
try:
|
||||
from sqlalchemy.engine import Row
|
||||
except ImportError:
|
||||
raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")
|
||||
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
@@ -275,7 +271,7 @@ class RelytVector(BaseVector):
|
||||
|
||||
# Execute the query and fetch the results
|
||||
with self.client.connect() as conn:
|
||||
results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
|
||||
results = conn.execute(sql_text(sql_query), params).fetchall()
|
||||
|
||||
documents_with_scores = [
|
||||
(
|
||||
@@ -307,11 +303,11 @@ class RelytVectorFactory(AbstractVectorFactory):
|
||||
return RelytVector(
|
||||
collection_name=collection_name,
|
||||
config=RelytConfig(
|
||||
host=dify_config.RELYT_HOST,
|
||||
host=dify_config.RELYT_HOST or "localhost",
|
||||
port=dify_config.RELYT_PORT,
|
||||
user=dify_config.RELYT_USER,
|
||||
password=dify_config.RELYT_PASSWORD,
|
||||
database=dify_config.RELYT_DATABASE,
|
||||
user=dify_config.RELYT_USER or "",
|
||||
password=dify_config.RELYT_PASSWORD or "",
|
||||
database=dify_config.RELYT_DATABASE or "default",
|
||||
),
|
||||
group_id=dataset.id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user