mirror of
https://github.com/langgenius/dify.git
synced 2026-02-16 07:01:44 -05:00
fix: avoid unexpected error when create knowledge base with baidu vector database and wenxin embedding model (#10130)
This commit is contained in:
@@ -3,11 +3,13 @@ import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymochow import MochowClient
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||
|
||||
@@ -116,6 +118,7 @@ class BaiduVector(BaseVector):
|
||||
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
@@ -149,7 +152,13 @@ class BaiduVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
try:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.TABLE_NOT_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def _init_client(self, config) -> MochowClient:
|
||||
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
|
||||
@@ -166,7 +175,14 @@ class BaiduVector(BaseVector):
|
||||
if exists:
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
try:
|
||||
self._client.create_database(database_name=self._client_config.database)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.DB_ALREADY_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
return
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
@@ -175,7 +191,7 @@ class BaiduVector(BaseVector):
|
||||
def _create_table(self, dimension: int) -> None:
|
||||
# Try to grab distributed lock and create table
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
with redis_client.lock(lock_name, timeout=60):
|
||||
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(table_exist_cache_key):
|
||||
return
|
||||
@@ -238,15 +254,14 @@ class BaiduVector(BaseVector):
|
||||
description="Table for Dify",
|
||||
)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
|
||||
|
||||
Reference in New Issue
Block a user