remove bare list, dict, Sequence, None, Any (#25058)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Author: Asuka Minato
Date: 2025-09-06 04:32:23 +09:00 (committed by GitHub)
Parent: 2b0695bdde
Commit: a78339a040
306 changed files with 787 additions and 817 deletions
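
The commit applies one mechanical rule: return annotations that carry no information, such as -> None on procedures and bare -> dict or -> Any on helpers whose return type is evident from the body, are dropped, while informative annotations such as -> bool are kept. A minimal before/after sketch of the rule (the class name is illustrative, not from the diff):

# Before
class ExampleVector:
    def text_exists(self, id: str) -> bool: ...           # informative: kept
    def delete_by_ids(self, ids: list[str]) -> None: ...  # redundant: dropped

# After
class ExampleVector:
    def text_exists(self, id: str) -> bool: ...
    def delete_by_ids(self, ids: list[str]): ...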

View File

@@ -76,7 +76,7 @@ class Jieba(BaseKeyword):
return False
return id in set.union(*keyword_table.values())
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
@@ -116,7 +116,7 @@ class Jieba(BaseKeyword):
return documents
-def delete(self) -> None:
+def delete(self):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
dataset_keyword_table = self.dataset.dataset_keyword_table
@@ -168,14 +168,14 @@ class Jieba(BaseKeyword):
return {}
-def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
+def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]):
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
-def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
+def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]):
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
@@ -251,7 +251,7 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
-def set_orjson_default(obj: Any) -> Any:
+def set_orjson_default(obj: Any):
"""Default function for orjson serialization of set types"""
if isinstance(obj, set):
return list(obj)
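
For context, set_orjson_default is the kind of hook orjson's default= parameter expects: orjson calls it for any object it cannot serialize natively, and the hook should raise for types it does not handle. A minimal usage sketch (the trailing raise and the payload are added here for illustration, not from the diff):

import orjson

def set_orjson_default(obj):
    """Default function for orjson serialization of set types"""
    if isinstance(obj, set):
        return list(obj)
    raise TypeError  # anything unhandled here becomes an encode error

# keyword tables map keywords to sets of segment ids; sets become JSON arrays
print(orjson.dumps({"python": {"seg-1", "seg-2"}}, default=set_orjson_default))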

View File

@@ -24,11 +24,11 @@ class BaseKeyword(ABC):
raise NotImplementedError
@abstractmethod
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
raise NotImplementedError
@abstractmethod
-def delete(self) -> None:
+def delete(self):
raise NotImplementedError
@abstractmethod
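
Dropping -> None from these abstract methods does not weaken the contract: abc enforces that concrete subclasses implement the named methods regardless of annotations. A minimal sketch under that assumption (InMemoryKeyword is hypothetical):

from abc import ABC, abstractmethod

class BaseKeyword(ABC):
    @abstractmethod
    def delete_by_ids(self, ids: list[str]):
        raise NotImplementedError

class InMemoryKeyword(BaseKeyword):  # hypothetical subclass
    def __init__(self):
        self.table: dict[str, set[str]] = {}

    def delete_by_ids(self, ids: list[str]):
        for keyword in list(self.table):
            self.table[keyword] -= set(ids)

# omitting the override would make instantiation raise TypeError
InMemoryKeyword().delete_by_ids(["doc-1"])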

View File

@@ -36,10 +36,10 @@ class Keyword:
def text_exists(self, id: str) -> bool:
return self._keyword_processor.text_exists(id)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
self._keyword_processor.delete_by_ids(ids)
-def delete(self) -> None:
+def delete(self):
self._keyword_processor.delete()
def search(self, query: str, **kwargs: Any) -> list[Document]:

View File

@@ -46,10 +46,10 @@ class AnalyticdbVector(BaseVector):
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
self.analyticdb_vector.delete_by_ids(ids)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
@@ -58,7 +58,7 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
-def delete(self) -> None:
+def delete(self):
self.analyticdb_vector.delete()

View File

@@ -26,7 +26,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["access_key_id"]:
raise ValueError("config ANALYTICDB_KEY_ID is required")
if not values["access_key_secret"]:
@@ -65,7 +65,7 @@ class AnalyticdbVectorOpenAPI:
self._client = Client(self._client_config)
self._initialize()
-def _initialize(self) -> None:
+def _initialize(self):
cache_key = f"vector_initialize_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
@@ -76,7 +76,7 @@ class AnalyticdbVectorOpenAPI:
self._create_namespace_if_not_exists()
redis_client.set(database_exist_cache_key, 1, ex=3600)
-def _initialize_vector_database(self) -> None:
+def _initialize_vector_database(self):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore
request = gpdb_20160503_models.InitVectorDatabaseRequest(
@@ -87,7 +87,7 @@ class AnalyticdbVectorOpenAPI:
)
self._client.init_vector_database(request)
-def _create_namespace_if_not_exists(self) -> None:
+def _create_namespace_if_not_exists(self):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException # type: ignore
@@ -200,7 +200,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
@@ -216,7 +216,7 @@ class AnalyticdbVectorOpenAPI:
)
self._client.delete_collection_data(request)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
@@ -305,7 +305,7 @@ class AnalyticdbVectorOpenAPI:
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
-def delete(self) -> None:
+def delete(self):
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

View File

@@ -23,7 +23,7 @@ class AnalyticdbVectorBySqlConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config ANALYTICDB_HOST is required")
if not values["port"]:
@@ -52,7 +52,7 @@ class AnalyticdbVectorBySql:
if not self.pool:
self.pool = self._create_connection_pool()
-def _initialize(self) -> None:
+def _initialize(self):
cache_key = f"vector_initialize_{self.config.host}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
@@ -85,7 +85,7 @@ class AnalyticdbVectorBySql:
conn.commit()
self.pool.putconn(conn)
-def _initialize_vector_database(self) -> None:
+def _initialize_vector_database(self):
conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
@@ -188,7 +188,7 @@ class AnalyticdbVectorBySql:
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
return cur.fetchone() is not None
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
with self._get_cursor() as cur:
@@ -198,7 +198,7 @@ class AnalyticdbVectorBySql:
if "does not exist" not in str(e):
raise e
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
@@ -270,6 +270,6 @@ class AnalyticdbVectorBySql:
documents.append(doc)
return documents
-def delete(self) -> None:
+def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -36,7 +36,7 @@ class BaiduConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["endpoint"]:
raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
if not values["account"]:
@@ -66,7 +66,7 @@ class BaiduVector(BaseVector):
def get_type(self) -> str:
return VectorType.BAIDU
-def to_index_struct(self) -> dict:
+def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
@@ -111,13 +111,13 @@ class BaiduVector(BaseVector):
return True
return False
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
@@ -164,7 +164,7 @@ class BaiduVector(BaseVector):
return docs
-def delete(self) -> None:
+def delete(self):
try:
self._db.drop_table(table_name=self._collection_name)
except ServerError as e:
@@ -201,7 +201,7 @@ class BaiduVector(BaseVector):
tables = self._db.list_table()
return any(table.table_name == self._collection_name for table in tables)
-def _create_table(self, dimension: int) -> None:
+def _create_table(self, dimension: int):
# Try to grab distributed lock and create table
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=60):
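
_create_table, like the _initialize methods earlier, serializes one-time setup behind a Redis distributed lock plus a cache key so only one worker runs the DDL. A condensed sketch of that pattern, assuming redis_client is a redis.Redis instance and create_table is a hypothetical DDL call:

from redis import Redis

redis_client = Redis()

def ensure_collection(collection_name: str, create_table):
    lock_name = f"vector_indexing_lock_{collection_name}"
    with redis_client.lock(lock_name, timeout=60):  # blocks concurrent creators
        cache_key = f"vector_indexing_{collection_name}"
        if redis_client.get(cache_key):
            return  # another worker already created the table
        create_table()
        redis_client.set(cache_key, 1, ex=3600)  # remember existence for an hour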

View File

@@ -82,7 +82,7 @@ class ChromaVector(BaseVector):
def delete(self):
self._client.delete_collection(self._collection_name)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)

View File

@@ -49,7 +49,7 @@ class ClickzettaConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
"""
Validate the configuration values.
"""
@@ -134,7 +134,7 @@ class ClickzettaConnectionPool:
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
-def _configure_connection(self, connection: "Connection") -> None:
+def _configure_connection(self, connection: "Connection"):
"""Configure connection session settings."""
try:
with connection.cursor() as cursor:
@@ -221,7 +221,7 @@ class ClickzettaConnectionPool:
# No valid connection found, create new one
return self._create_connection(config)
-def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None:
+def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
"""Return a connection to the pool."""
config_key = self._get_config_key(config)
@@ -243,7 +243,7 @@ class ClickzettaConnectionPool:
with contextlib.suppress(Exception):
connection.close()
-def _cleanup_expired_connections(self) -> None:
+def _cleanup_expired_connections(self):
"""Clean up expired connections from all pools."""
current_time = time.time()
@@ -265,7 +265,7 @@ class ClickzettaConnectionPool:
self._pools[config_key] = valid_connections
-def _start_cleanup_thread(self) -> None:
+def _start_cleanup_thread(self):
"""Start background thread for connection cleanup."""
def cleanup_worker():
@@ -280,7 +280,7 @@ class ClickzettaConnectionPool:
self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
self._cleanup_thread.start()
-def shutdown(self) -> None:
+def shutdown(self):
"""Shutdown connection pool and close all connections."""
self._shutdown = True
@@ -319,7 +319,7 @@ class ClickzettaVector(BaseVector):
"""Get a connection from the pool."""
return self._connection_pool.get_connection(self._config)
-def _return_connection(self, connection: "Connection") -> None:
+def _return_connection(self, connection: "Connection"):
"""Return a connection to the pool."""
self._connection_pool.return_connection(self._config, connection)
@@ -342,7 +342,7 @@ class ClickzettaVector(BaseVector):
"""Get a connection context manager."""
return self.ConnectionContext(self)
-def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict:
+def _parse_metadata(self, raw_metadata: str, row_id: str):
"""
Parse metadata from JSON string with proper error handling and fallback.
@@ -723,7 +723,7 @@ class ClickzettaVector(BaseVector):
result = cursor.fetchone()
return result[0] > 0 if result else False
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
"""Delete documents by IDs."""
if not ids:
return
@@ -736,7 +736,7 @@ class ClickzettaVector(BaseVector):
# Execute delete through write queue
self._execute_write(self._delete_by_ids_impl, ids)
-def _delete_by_ids_impl(self, ids: list[str]) -> None:
+def _delete_by_ids_impl(self, ids: list[str]):
"""Implementation of delete by IDs (executed in write worker thread)."""
safe_ids = [self._safe_doc_id(id) for id in ids]
@@ -748,7 +748,7 @@ class ClickzettaVector(BaseVector):
with connection.cursor() as cursor:
cursor.execute(sql, binding_params=safe_ids)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
"""Delete documents by metadata field."""
# Check if table exists before attempting delete
if not self._table_exists():
@@ -758,7 +758,7 @@ class ClickzettaVector(BaseVector):
# Execute delete through write queue
self._execute_write(self._delete_by_metadata_field_impl, key, value)
-def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
+def _delete_by_metadata_field_impl(self, key: str, value: str):
"""Implementation of delete by metadata field (executed in write worker thread)."""
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
@@ -1027,7 +1027,7 @@ class ClickzettaVector(BaseVector):
return documents
-def delete(self) -> None:
+def delete(self):
"""Delete the entire collection."""
with self.get_connection_context() as connection:
with connection.cursor() as cursor:

View File

@@ -36,7 +36,7 @@ class CouchbaseConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values.get("connection_string"):
raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
if not values.get("user"):
@@ -234,7 +234,7 @@ class CouchbaseVector(BaseVector):
return bool(row["count"] > 0)
return False # Return False if no rows are returned
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
query = f"""
DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id IN $doc_ids;
@@ -261,7 +261,7 @@ class CouchbaseVector(BaseVector):
# result = self._cluster.query(query, named_parameters={'value':value})
# return [row['id'] for row in result.rows()]
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
query = f"""
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE metadata.{key} = $value;

View File

@@ -43,7 +43,7 @@ class ElasticSearchConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
use_cloud = values.get("use_cloud", False)
cloud_url = values.get("cloud_url")
@@ -174,20 +174,20 @@ class ElasticSearchVector(BaseVector):
def text_exists(self, id: str) -> bool:
return bool(self._client.exists(index=self._collection_name, id=id))
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
-def delete(self) -> None:
+def delete(self):
self._client.indices.delete(index=self._collection_name)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

View File

@@ -33,7 +33,7 @@ class HuaweiCloudVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["hosts"]:
raise ValueError("config HOSTS is required")
return values
@@ -78,20 +78,20 @@ class HuaweiCloudVector(BaseVector):
def text_exists(self, id: str) -> bool:
return bool(self._client.exists(index=self._collection_name, id=id))
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
-def delete(self) -> None:
+def delete(self):
self._client.indices.delete(index=self._collection_name)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

View File

@@ -36,7 +36,7 @@ class LindormVectorStoreConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["hosts"]:
raise ValueError("config URL is required")
if not values["username"]:
@@ -167,7 +167,7 @@ class LindormVectorStore(BaseVector):
if ids:
self.delete_by_ids(ids)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
"""Delete documents by their IDs in batch.
Args:
@@ -213,7 +213,7 @@ class LindormVectorStore(BaseVector):
else:
logger.exception("Error deleting document: %s", error)
-def delete(self) -> None:
+def delete(self):
if self._using_ugc:
routing_filter_query = {
"query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
@@ -372,7 +372,7 @@ class LindormVectorStore(BaseVector):
# logger.info(f"create index success: {self._collection_name}")
-def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
+def default_text_mapping(dimension: int, method_name: str, **kwargs: Any):
excludes_from_source = kwargs.get("excludes_from_source", False)
analyzer = kwargs.get("analyzer", "ik_max_word")
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
@@ -456,7 +456,7 @@ def default_text_search_query(
routing: Optional[str] = None,
routing_field: Optional[str] = None,
**kwargs,
-) -> dict:
+):
query_clause: dict[str, Any] = {}
if routing is not None:
query_clause = {
@@ -513,7 +513,7 @@ def default_vector_search_query(
filters: Optional[list[dict]] = None,
filter_type: Optional[str] = None,
**kwargs,
-) -> dict:
+):
if filters is not None:
filter_type = "pre_filter" if filter_type is None else filter_type
if not isinstance(filters, list):

View File

@@ -29,7 +29,7 @@ class MatrixoneConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config host is required")
if not values["port"]:
@@ -128,7 +128,7 @@ class MatrixoneVector(BaseVector):
return len(result) > 0
@ensure_client
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
assert self.client is not None
if not ids:
return
@@ -141,7 +141,7 @@ class MatrixoneVector(BaseVector):
return [result.id for result in results]
@ensure_client
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
assert self.client is not None
self.client.delete(filter={key: value})
@@ -207,7 +207,7 @@ class MatrixoneVector(BaseVector):
return docs
@ensure_client
-def delete(self) -> None:
+def delete(self):
assert self.client is not None
self.client.delete()

View File

@@ -36,7 +36,7 @@ class MilvusConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
@@ -79,7 +79,7 @@ class MilvusVector(BaseVector):
self._load_collection_fields()
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
-def _load_collection_fields(self, fields: Optional[list[str]] = None) -> None:
+def _load_collection_fields(self, fields: Optional[list[str]] = None):
if fields is None:
# Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name)
@@ -171,7 +171,7 @@ class MilvusVector(BaseVector):
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
"""
Delete documents by their IDs.
"""
@@ -183,7 +183,7 @@ class MilvusVector(BaseVector):
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
-def delete(self) -> None:
+def delete(self):
"""
Delete the entire collection.
"""

View File

@@ -101,7 +101,7 @@ class MyScaleVector(BaseVector):
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
return results.row_count > 0
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
self._client.command(
@@ -114,7 +114,7 @@ class MyScaleVector(BaseVector):
).result_rows
return [row[0] for row in rows]
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
)
@@ -156,7 +156,7 @@ class MyScaleVector(BaseVector):
logger.exception("Vector search operation failed")
return []
-def delete(self) -> None:
+def delete(self):
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")

View File

@@ -35,7 +35,7 @@ class OceanBaseVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config OCEANBASE_VECTOR_HOST is required")
if not values["port"]:
@@ -68,7 +68,7 @@ class OceanBaseVector(BaseVector):
self._create_collection()
self.add_texts(texts, embeddings)
-def _create_collection(self) -> None:
+def _create_collection(self):
lock_name = "vector_indexing_lock_" + self._collection_name
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_" + self._collection_name
@@ -174,7 +174,7 @@ class OceanBaseVector(BaseVector):
cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids)
@@ -190,7 +190,7 @@ class OceanBaseVector(BaseVector):
)
return [row[0] for row in cur]
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
@@ -278,7 +278,7 @@ class OceanBaseVector(BaseVector):
)
return docs
-def delete(self) -> None:
+def delete(self):
self._client.drop_table_if_exist(self._collection_name)

View File

@@ -29,7 +29,7 @@ class OpenGaussConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config OPENGAUSS_HOST is required")
if not values["port"]:
@@ -159,7 +159,7 @@ class OpenGauss(BaseVector):
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
@@ -168,7 +168,7 @@ class OpenGauss(BaseVector):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
@@ -222,7 +222,7 @@ class OpenGauss(BaseVector):
return docs
-def delete(self) -> None:
+def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -33,7 +33,7 @@ class OpenSearchConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
@@ -128,7 +128,7 @@ class OpenSearchVector(BaseVector):
if ids:
self.delete_by_ids(ids)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
index_name = self._collection_name.lower()
if not self._client.indices.exists(index=index_name):
logger.warning("Index %s does not exist", index_name)
@@ -159,7 +159,7 @@ class OpenSearchVector(BaseVector):
else:
logger.exception("Error deleting document: %s", error)
-def delete(self) -> None:
+def delete(self):
self._client.indices.delete(index=self._collection_name.lower())
def text_exists(self, id: str) -> bool:

View File

@@ -33,7 +33,7 @@ class OracleVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["user"]:
raise ValueError("config ORACLE_USER is required")
if not values["password"]:
@@ -206,7 +206,7 @@ class OracleVector(BaseVector):
conn.close()
return docs
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
with self._get_connection() as conn:
@@ -216,7 +216,7 @@ class OracleVector(BaseVector):
conn.commit()
conn.close()
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
@@ -336,7 +336,7 @@ class OracleVector(BaseVector):
else:
return [Document(page_content="", metadata={})]
-def delete(self) -> None:
+def delete(self):
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")

View File

@@ -33,7 +33,7 @@ class PgvectoRSConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values["port"]:
@@ -150,7 +150,7 @@ class PGVectoRS(BaseVector):
session.execute(select_statement, {"ids": ids})
session.commit()
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
@@ -164,7 +164,7 @@ class PGVectoRS(BaseVector):
session.execute(select_statement, {"ids": ids})
session.commit()
-def delete(self) -> None:
+def delete(self):
with Session(self._client) as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
session.commit()

View File

@@ -34,7 +34,7 @@ class PGVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")
if not values["port"]:
@@ -146,7 +146,7 @@ class PGVector(BaseVector):
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
@@ -162,7 +162,7 @@ class PGVector(BaseVector):
except Exception as e:
raise e
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
@@ -242,7 +242,7 @@ class PGVector(BaseVector):
return docs
-def delete(self) -> None:
+def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -28,7 +28,7 @@ class VastbaseVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config VASTBASE_HOST is required")
if not values["port"]:
@@ -133,7 +133,7 @@ class VastbaseVector(BaseVector):
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
@@ -142,7 +142,7 @@ class VastbaseVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
@@ -199,7 +199,7 @@ class VastbaseVector(BaseVector):
return docs
-def delete(self) -> None:
+def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -81,7 +81,7 @@ class QdrantVector(BaseVector):
def get_type(self) -> str:
return VectorType.QDRANT
-def to_index_struct(self) -> dict:
+def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
@@ -292,7 +292,7 @@ class QdrantVector(BaseVector):
else:
raise e
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

View File

@@ -35,7 +35,7 @@ class RelytConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
if not values["port"]:
@@ -64,7 +64,7 @@ class RelytVector(BaseVector):
def get_type(self) -> str:
return VectorType.RELYT
-def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
+def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
self.add_texts(texts, embeddings)
@@ -196,7 +196,7 @@ class RelytVector(BaseVector):
if ids:
self.delete_by_uuids(ids)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
with Session(self.client) as session:
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
@@ -207,7 +207,7 @@ class RelytVector(BaseVector):
ids = [item[0] for item in result]
self.delete_by_uuids(ids)
-def delete(self) -> None:
+def delete(self):
with Session(self.client) as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()

View File

@@ -30,7 +30,7 @@ class TableStoreConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["access_key_id"]:
raise ValueError("config ACCESS_KEY_ID is required")
if not values["access_key_secret"]:
@@ -112,7 +112,7 @@ class TableStoreVector(BaseVector):
return return_row is not None
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
@@ -121,7 +121,7 @@ class TableStoreVector(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
return self._search_by_metadata(key, value)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
@@ -143,7 +143,7 @@ class TableStoreVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
-def delete(self) -> None:
+def delete(self):
self._delete_table_if_exist()
def _create_collection(self, dimension: int):
@@ -158,7 +158,7 @@ class TableStoreVector(BaseVector):
self._create_search_index_if_not_exist(dimension)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
-def _create_table_if_not_exist(self) -> None:
+def _create_table_if_not_exist(self):
table_list = self._tablestore_client.list_table()
if self._table_name in table_list:
logger.info("Tablestore system table[%s] already exists", self._table_name)
@@ -171,7 +171,7 @@ class TableStoreVector(BaseVector):
self._tablestore_client.create_table(table_meta, table_options, reserved_throughput)
logger.info("Tablestore create table[%s] successfully.", self._table_name)
-def _create_search_index_if_not_exist(self, dimension: int) -> None:
+def _create_search_index_if_not_exist(self, dimension: int):
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
assert isinstance(search_index_list, Iterable)
if self._index_name in [t[1] for t in search_index_list]:
@@ -225,11 +225,11 @@ class TableStoreVector(BaseVector):
self._tablestore_client.delete_table(self._table_name)
logger.info("Tablestore delete system table[%s] successfully.", self._index_name)
-def _delete_search_index(self) -> None:
+def _delete_search_index(self):
self._tablestore_client.delete_search_index(self._table_name, self._index_name)
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
-def _write_row(self, primary_key: str, attributes: dict[str, Any]) -> None:
+def _write_row(self, primary_key: str, attributes: dict[str, Any]):
pk = [("id", primary_key)]
tags = []
@@ -248,7 +248,7 @@ class TableStoreVector(BaseVector):
row = tablestore.Row(pk, attribute_columns)
self._tablestore_client.put_row(self._table_name, row)
-def _delete_row(self, id: str) -> None:
+def _delete_row(self, id: str):
primary_key = [("id", id)]
row = tablestore.Row(primary_key)
self._tablestore_client.delete_row(self._table_name, row, None)

View File

@@ -82,7 +82,7 @@ class TencentVector(BaseVector):
def get_type(self) -> str:
return VectorType.TENCENT
-def to_index_struct(self) -> dict:
+def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def _has_collection(self) -> bool:
@@ -92,7 +92,7 @@ class TencentVector(BaseVector):
)
)
-def _create_collection(self, dimension: int) -> None:
+def _create_collection(self, dimension: int):
self._dimension = dimension
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
@@ -205,7 +205,7 @@ class TencentVector(BaseVector):
return True
return False
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
if not ids:
return
@@ -222,7 +222,7 @@ class TencentVector(BaseVector):
database_name=self._client_config.database, collection_name=self.collection_name, document_ids=batch_ids
)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
self._client.delete(
database_name=self._client_config.database,
collection_name=self.collection_name,
@@ -299,7 +299,7 @@ class TencentVector(BaseVector):
docs.append(doc)
return docs
-def delete(self) -> None:
+def delete(self):
if self._has_collection():
self._client.drop_collection(
database_name=self._client_config.database, collection_name=self.collection_name

View File

@@ -90,7 +90,7 @@ class TidbOnQdrantVector(BaseVector):
def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT
-def to_index_struct(self) -> dict:
+def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
@@ -284,7 +284,7 @@ class TidbOnQdrantVector(BaseVector):
else:
raise e
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

View File

@@ -31,7 +31,7 @@ class TiDBVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
if not values["port"]:
@@ -144,7 +144,7 @@ class TiDBVector(BaseVector):
result = self.get_ids_by_metadata_field("doc_id", id)
return bool(result)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
with Session(self._engine) as session:
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
@@ -179,7 +179,7 @@ class TiDBVector(BaseVector):
else:
return None
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._delete_by_ids(ids)
@@ -237,7 +237,7 @@ class TiDBVector(BaseVector):
# tidb doesn't support bm25 search
return []
-def delete(self) -> None:
+def delete(self):
with Session(self._engine) as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()

View File

@@ -20,7 +20,7 @@ class UpstashVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["url"]:
raise ValueError("Upstash URL is required")
if not values["token"]:
@@ -60,7 +60,7 @@ class UpstashVector(BaseVector):
response = self.get_ids_by_metadata_field("doc_id", id)
return len(response) > 0
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
item_ids = []
for doc_id in ids:
ids = self.get_ids_by_metadata_field("doc_id", doc_id)
@@ -68,7 +68,7 @@ class UpstashVector(BaseVector):
item_ids += ids
self._delete_by_ids(ids=item_ids)
-def _delete_by_ids(self, ids: list[str]) -> None:
+def _delete_by_ids(self, ids: list[str]):
if ids:
self.index.delete(ids=ids)
@@ -81,7 +81,7 @@ class UpstashVector(BaseVector):
)
return [result.id for result in query_result]
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._delete_by_ids(ids)
@@ -117,7 +117,7 @@ class UpstashVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
-def delete(self) -> None:
+def delete(self):
self.index.reset()
def get_type(self) -> str:

View File

@@ -27,14 +27,14 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
@@ -46,7 +46,7 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
-def delete(self) -> None:
+def delete(self):
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:

View File

@@ -26,7 +26,7 @@ class AbstractVectorFactory(ABC):
raise NotImplementedError
@staticmethod
-def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
+def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict
@@ -207,10 +207,10 @@ class Vector:
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
self._vector_processor.delete_by_ids(ids)
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
@@ -220,7 +220,7 @@ class Vector:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
-def delete(self) -> None:
+def delete(self):
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:

View File

@@ -144,7 +144,7 @@ class VikingDBVector(BaseVector):
return True
return False
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
self._client.get_collection(self._collection_name).delete_data(ids)
def get_ids_by_metadata_field(self, key: str, value: str):
@@ -168,7 +168,7 @@ class VikingDBVector(BaseVector):
ids.append(result.id)
return ids
-def delete_by_metadata_field(self, key: str, value: str) -> None:
+def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
@@ -202,7 +202,7 @@ class VikingDBVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
-def delete(self) -> None:
+def delete(self):
if self._has_index():
self._client.drop_index(self._collection_name, self._index_name)
if self._has_collection():

View File

@@ -24,7 +24,7 @@ class WeaviateConfig(BaseModel):
@model_validator(mode="before")
@classmethod
-def validate_config(cls, values: dict) -> dict:
+def validate_config(cls, values: dict):
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
@@ -75,7 +75,7 @@ class WeaviateVector(BaseVector):
dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id)
-def to_index_struct(self) -> dict:
+def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
@@ -164,7 +164,7 @@ class WeaviateVector(BaseVector):
return True
-def delete_by_ids(self, ids: list[str]) -> None:
+def delete_by_ids(self, ids: list[str]):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
@@ -256,7 +256,7 @@ class WeaviateVector(BaseVector):
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs
-def _default_schema(self, index_name: str) -> dict:
+def _default_schema(self, index_name: str):
return {
"class": index_name,
"properties": [
@@ -267,7 +267,7 @@ class WeaviateVector(BaseVector):
],
}
-def _json_serializable(self, value: Any) -> Any:
+def _json_serializable(self, value: Any):
if isinstance(value, datetime.datetime):
return value.isoformat()
return value

View File

@@ -32,11 +32,11 @@ class DatasetDocumentStore:
}
@property
-def dataset_id(self) -> Any:
+def dataset_id(self):
return self._dataset.id
@property
-def user_id(self) -> Any:
+def user_id(self):
return self._user_id
@property
@@ -59,7 +59,7 @@ class DatasetDocumentStore:
return output
-def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
+def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == self._document_id)
@@ -195,7 +195,7 @@ class DatasetDocumentStore:
},
)
-def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
+def delete_document(self, doc_id: str, raise_error: bool = True):
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
@@ -207,7 +207,7 @@ class DatasetDocumentStore:
db.session.delete(document_segment)
db.session.commit()
-def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
+def set_document_hash(self, doc_id: str, doc_hash: str):
"""Set the hash for a given doc_id."""
document_segment = self.get_document_segment(doc_id)

View File

@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
-def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
+def __init__(self, model_instance: ModelInstance, user: Optional[str] = None):
self._model_instance = model_instance
self._user = user

View File

@@ -18,7 +18,7 @@ class NotionInfo(BaseModel):
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
-def __init__(self, **data) -> None:
+def __init__(self, **data):
super().__init__(**data)
@@ -49,5 +49,5 @@ class ExtractSetting(BaseModel):
document_model: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
-def __init__(self, **data) -> None:
+def __init__(self, **data):
super().__init__(**data)

View File

@@ -122,7 +122,7 @@ class FirecrawlApp:
return response
return response
-def _handle_error(self, response, action) -> None:
+def _handle_error(self, response, action):
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]

View File

@@ -29,7 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
"""
import chardet
-def read_and_detect(file_path: str) -> list[dict]:
+def read_and_detect(file_path: str):
with open(file_path, "rb") as f:
# Read only a sample of the file for encoding detection
# This prevents timeout on large files while still providing accurate encoding detection
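
A sketch of what read_and_detect plausibly does with the sample, assuming chardet's detect_all API (which returns a ranked list of candidate encodings, matching the list[dict] the dropped annotation described); the sample_size default is an assumption, since the original value is truncated above:

import chardet

def read_and_detect(file_path: str, sample_size: int = 1024 * 1024):
    with open(file_path, "rb") as f:
        sample = f.read(sample_size)  # sample only, to stay fast on large files
    # e.g. [{"encoding": "utf-8", "confidence": 0.99, "language": ""}]
    return chardet.detect_all(sample)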

View File

@@ -9,7 +9,7 @@ class WaterCrawlProvider:
def __init__(self, api_key, base_url: str | None = None):
self.client = WaterCrawlAPIClient(api_key, base_url)
-def crawl_url(self, url, options: Optional[dict | Any] = None) -> dict:
+def crawl_url(self, url, options: Optional[dict | Any] = None):
options = options or {}
spider_options = {
"max_depth": 1,
@@ -41,7 +41,7 @@ class WaterCrawlProvider:
return {"status": "active", "job_id": result.get("uuid")}
-def get_crawl_status(self, crawl_request_id) -> dict:
+def get_crawl_status(self, crawl_request_id):
response = self.client.get_crawl_request(crawl_request_id)
data = []
if response["status"] in ["new", "running"]:
@@ -82,11 +82,11 @@ class WaterCrawlProvider:
return None
-def scrape_url(self, url: str) -> dict:
+def scrape_url(self, url: str):
response = self.client.scrape_url(url=url, sync=True, prefetched=True)
return self._structure_data(response)
-def _structure_data(self, result_object: dict) -> dict:
+def _structure_data(self, result_object: dict):
if isinstance(result_object.get("result", {}), str):
raise ValueError("Invalid result object. Expected a dictionary.")

View File

@@ -56,7 +56,7 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
-def __del__(self) -> None:
+def __del__(self):
if hasattr(self, "temp_file"):
self.temp_file.close()

View File

@@ -6,7 +6,7 @@ from core.rag.rerank.rerank_base import BaseRerankRunner
class RerankModelRunner(BaseRerankRunner):
-def __init__(self, rerank_model_instance: ModelInstance) -> None:
+def __init__(self, rerank_model_instance: ModelInstance):
self.rerank_model_instance = rerank_model_instance
def run(

View File

@@ -14,7 +14,7 @@ from core.rag.rerank.rerank_base import BaseRerankRunner
class WeightRerankRunner(BaseRerankRunner):
-def __init__(self, tenant_id: str, weights: Weights) -> None:
+def __init__(self, tenant_id: str, weights: Weights):
self.tenant_id = tenant_id
self.weights = weights
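
On the __init__ hunks here and elsewhere in the commit: __init__ can only return None, and mypy treats an __init__ with annotated parameters as fully typed even without the return annotation, so nothing is lost. A minimal sketch (the Weights stand-in is illustrative):

class Weights:  # stand-in for the real Weights model
    pass

class WeightRerankRunner:
    def __init__(self, tenant_id: str, weights: Weights):  # -> None dropped, still type-checked
        self.tenant_id = tenant_id
        self.weights = weights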

View File

@@ -507,7 +507,7 @@ class DatasetRetrieval:
def _on_retrieval_end(
self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
-) -> None:
+):
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
@@ -560,7 +560,7 @@ class DatasetRetrieval:
)
)
-def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
+def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
"""
Handle query.
"""

View File

@@ -47,7 +47,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
keep_separator: bool = False,
add_start_index: bool = False,
-) -> None:
+):
"""Create a new TextSplitter.
Args:
@@ -201,7 +201,7 @@ class TokenTextSplitter(TextSplitter):
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
-) -> None:
+):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
try:
@@ -251,7 +251,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
separators: Optional[list[str]] = None,
keep_separator: bool = True,
**kwargs: Any,
-) -> None:
+):
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]