mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
Feature integrate alibabacloud mysql vector (#25994)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test
|
|||||||
OCEANBASE_MEMORY_LIMIT=6G
|
OCEANBASE_MEMORY_LIMIT=6G
|
||||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||||
|
|
||||||
|
# AlibabaCloud MySQL Vector configuration
|
||||||
|
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||||
|
ALIBABACLOUD_MYSQL_PORT=3306
|
||||||
|
ALIBABACLOUD_MYSQL_USER=root
|
||||||
|
ALIBABACLOUD_MYSQL_PASSWORD=root
|
||||||
|
ALIBABACLOUD_MYSQL_DATABASE=dify
|
||||||
|
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||||
|
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||||
|
|
||||||
# openGauss configuration
|
# openGauss configuration
|
||||||
OPENGAUSS_HOST=127.0.0.1
|
OPENGAUSS_HOST=127.0.0.1
|
||||||
OPENGAUSS_PORT=6600
|
OPENGAUSS_PORT=6600
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
|
|||||||
from .storage.supabase_storage_config import SupabaseStorageConfig
|
from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||||
|
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
|
||||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||||
from .vdb.chroma_config import ChromaConfig
|
from .vdb.chroma_config import ChromaConfig
|
||||||
@@ -330,6 +331,7 @@ class MiddlewareConfig(
|
|||||||
ClickzettaConfig,
|
ClickzettaConfig,
|
||||||
HuaweiCloudConfig,
|
HuaweiCloudConfig,
|
||||||
MilvusConfig,
|
MilvusConfig,
|
||||||
|
AlibabaCloudMySQLConfig,
|
||||||
MyScaleConfig,
|
MyScaleConfig,
|
||||||
OpenSearchConfig,
|
OpenSearchConfig,
|
||||||
OracleConfig,
|
OracleConfig,
|
||||||
|
|||||||
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from pydantic import Field, PositiveInt
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class AlibabaCloudMySQLConfig(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration settings for AlibabaCloud MySQL vector database
|
||||||
|
"""
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_HOST: str = Field(
|
||||||
|
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
|
||||||
|
default="localhost",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
|
||||||
|
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
|
||||||
|
default=3306,
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_USER: str = Field(
|
||||||
|
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
|
||||||
|
default="root",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
|
||||||
|
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
|
||||||
|
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
|
||||||
|
default="dify",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
|
||||||
|
description="Maximum number of connections in the connection pool",
|
||||||
|
default=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
|
||||||
|
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
|
||||||
|
default="utf8mb4",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
|
||||||
|
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
|
||||||
|
"(e.g., 'cosine', 'euclidean')",
|
||||||
|
default="cosine",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
|
||||||
|
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
|
||||||
|
default=6,
|
||||||
|
)
|
||||||
@@ -810,6 +810,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||||||
| VectorType.MATRIXONE
|
| VectorType.MATRIXONE
|
||||||
| VectorType.CLICKZETTA
|
| VectorType.CLICKZETTA
|
||||||
| VectorType.BAIDU
|
| VectorType.BAIDU
|
||||||
|
| VectorType.ALIBABACLOUD_MYSQL
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
"retrieval_method": [
|
"retrieval_method": [
|
||||||
@@ -864,6 +865,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||||||
| VectorType.MATRIXONE
|
| VectorType.MATRIXONE
|
||||||
| VectorType.CLICKZETTA
|
| VectorType.CLICKZETTA
|
||||||
| VectorType.BAIDU
|
| VectorType.BAIDU
|
||||||
|
| VectorType.ALIBABACLOUD_MYSQL
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
"retrieval_method": [
|
"retrieval_method": [
|
||||||
|
|||||||
@@ -0,0 +1,388 @@
|
|||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
|
import mysql.connector
|
||||||
|
from mysql.connector import Error as MySQLError
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||||
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AlibabaCloudMySQLVectorConfig(BaseModel):
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
user: str
|
||||||
|
password: str
|
||||||
|
database: str
|
||||||
|
max_connection: int
|
||||||
|
charset: str = "utf8mb4"
|
||||||
|
distance_function: Literal["cosine", "euclidean"] = "cosine"
|
||||||
|
hnsw_m: int = 6
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_config(cls, values: dict):
|
||||||
|
if not values.get("host"):
|
||||||
|
raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
|
||||||
|
if not values.get("port"):
|
||||||
|
raise ValueError("config ALIBABACLOUD_MYSQL_PORT is required")
|
||||||
|
if not values.get("user"):
|
||||||
|
raise ValueError("config ALIBABACLOUD_MYSQL_USER is required")
|
||||||
|
if values.get("password") is None:
|
||||||
|
raise ValueError("config ALIBABACLOUD_MYSQL_PASSWORD is required")
|
||||||
|
if not values.get("database"):
|
||||||
|
raise ValueError("config ALIBABACLOUD_MYSQL_DATABASE is required")
|
||||||
|
if not values.get("max_connection"):
|
||||||
|
raise ValueError("config ALIBABACLOUD_MYSQL_MAX_CONNECTION is required")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
SQL_CREATE_TABLE = """
|
||||||
|
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||||
|
id VARCHAR(36) PRIMARY KEY,
|
||||||
|
text LONGTEXT NOT NULL,
|
||||||
|
meta JSON NOT NULL,
|
||||||
|
embedding VECTOR({dimension}) NOT NULL,
|
||||||
|
VECTOR INDEX (embedding) M={hnsw_m} DISTANCE={distance_function}
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
"""
|
||||||
|
|
||||||
|
SQL_CREATE_META_INDEX = """
|
||||||
|
CREATE INDEX idx_{index_hash}_meta ON {table_name}
|
||||||
|
((CAST(JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) AS CHAR(36))));
|
||||||
|
"""
|
||||||
|
|
||||||
|
SQL_CREATE_FULLTEXT_INDEX = """
|
||||||
|
CREATE FULLTEXT INDEX idx_{index_hash}_text ON {table_name} (text) WITH PARSER ngram;
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AlibabaCloudMySQLVector(BaseVector):
|
||||||
|
def __init__(self, collection_name: str, config: AlibabaCloudMySQLVectorConfig):
|
||||||
|
super().__init__(collection_name)
|
||||||
|
self.pool = self._create_connection_pool(config)
|
||||||
|
self.table_name = collection_name.lower()
|
||||||
|
self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
|
||||||
|
self.distance_function = config.distance_function.lower()
|
||||||
|
self.hnsw_m = config.hnsw_m
|
||||||
|
self._check_vector_support()
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return VectorType.ALIBABACLOUD_MYSQL
|
||||||
|
|
||||||
|
def _create_connection_pool(self, config: AlibabaCloudMySQLVectorConfig):
|
||||||
|
# Create connection pool using mysql-connector-python pooling
|
||||||
|
pool_config: dict[str, Any] = {
|
||||||
|
"host": config.host,
|
||||||
|
"port": config.port,
|
||||||
|
"user": config.user,
|
||||||
|
"password": config.password,
|
||||||
|
"database": config.database,
|
||||||
|
"charset": config.charset,
|
||||||
|
"autocommit": True,
|
||||||
|
"pool_name": f"pool_{self.collection_name}",
|
||||||
|
"pool_size": config.max_connection,
|
||||||
|
"pool_reset_session": True,
|
||||||
|
}
|
||||||
|
return mysql.connector.pooling.MySQLConnectionPool(**pool_config)
|
||||||
|
|
||||||
|
def _check_vector_support(self):
|
||||||
|
"""Check if the MySQL server supports vector operations."""
|
||||||
|
try:
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
# Check MySQL version and vector support
|
||||||
|
cur.execute("SELECT VERSION()")
|
||||||
|
version = cur.fetchone()["VERSION()"]
|
||||||
|
logger.debug("Connected to MySQL version: %s", version)
|
||||||
|
# Try to execute a simple vector function to verify support
|
||||||
|
cur.execute("SELECT VEC_FromText('[1,2,3]') IS NOT NULL as vector_support")
|
||||||
|
result = cur.fetchone()
|
||||||
|
if not result or not result.get("vector_support"):
|
||||||
|
raise ValueError(
|
||||||
|
"RDS MySQL Vector functions are not available."
|
||||||
|
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
if "FUNCTION" in str(e) and "VEC_FromText" in str(e):
|
||||||
|
raise ValueError(
|
||||||
|
"RDS MySQL Vector functions are not available."
|
||||||
|
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
|
||||||
|
) from e
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _get_cursor(self):
|
||||||
|
conn = self.pool.get_connection()
|
||||||
|
cur = conn.cursor(dictionary=True)
|
||||||
|
try:
|
||||||
|
yield cur
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
dimension = len(embeddings[0])
|
||||||
|
self._create_collection(dimension)
|
||||||
|
return self.add_texts(texts, embeddings)
|
||||||
|
|
||||||
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
values = []
|
||||||
|
pks = []
|
||||||
|
for i, doc in enumerate(documents):
|
||||||
|
if doc.metadata is not None:
|
||||||
|
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||||
|
pks.append(doc_id)
|
||||||
|
# Convert embedding list to Aliyun MySQL vector format
|
||||||
|
vector_str = "[" + ",".join(map(str, embeddings[i])) + "]"
|
||||||
|
values.append(
|
||||||
|
(
|
||||||
|
doc_id,
|
||||||
|
doc.page_content,
|
||||||
|
json.dumps(doc.metadata),
|
||||||
|
vector_str,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
insert_sql = (
|
||||||
|
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, VEC_FromText(%s))"
|
||||||
|
)
|
||||||
|
cur.executemany(insert_sql, values)
|
||||||
|
return pks
|
||||||
|
|
||||||
|
def text_exists(self, id: str) -> bool:
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
|
||||||
|
return cur.fetchone() is not None
|
||||||
|
|
||||||
|
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||||
|
if not ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
placeholders = ",".join(["%s"] * len(ids))
|
||||||
|
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||||
|
docs = []
|
||||||
|
for record in cur:
|
||||||
|
metadata = record["meta"]
|
||||||
|
if isinstance(metadata, str):
|
||||||
|
metadata = json.loads(metadata)
|
||||||
|
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def delete_by_ids(self, ids: list[str]):
|
||||||
|
# Avoiding crashes caused by performing delete operations on empty lists
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
try:
|
||||||
|
placeholders = ",".join(["%s"] * len(ids))
|
||||||
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||||
|
except MySQLError as e:
|
||||||
|
if e.errno == 1146: # Table doesn't exist
|
||||||
|
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
|
||||||
|
|
||||||
|
:param query_vector: The input vector to search for similar items.
|
||||||
|
:return: List of Documents that are nearest to the query vector.
|
||||||
|
"""
|
||||||
|
top_k = kwargs.get("top_k", 4)
|
||||||
|
if not isinstance(top_k, int) or top_k <= 0:
|
||||||
|
raise ValueError("top_k must be a positive integer")
|
||||||
|
|
||||||
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
|
where_clause = ""
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if document_ids_filter:
|
||||||
|
placeholders = ",".join(["%s"] * len(document_ids_filter))
|
||||||
|
where_clause = f" WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
|
||||||
|
params.extend(document_ids_filter)
|
||||||
|
|
||||||
|
# Convert query vector to RDS MySQL vector format
|
||||||
|
query_vector_str = "[" + ",".join(map(str, query_vector)) + "]"
|
||||||
|
|
||||||
|
# Use RSD MySQL's native vector distance functions
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
# Choose distance function based on configuration
|
||||||
|
distance_func = "VEC_DISTANCE_COSINE" if self.distance_function == "cosine" else "VEC_DISTANCE_EUCLIDEAN"
|
||||||
|
|
||||||
|
# Note: RDS MySQL optimizer will use vector index when ORDER BY + LIMIT are present
|
||||||
|
# Use column alias in ORDER BY to avoid calculating distance twice
|
||||||
|
sql = f"""
|
||||||
|
SELECT meta, text,
|
||||||
|
{distance_func}(embedding, VEC_FromText(%s)) AS distance
|
||||||
|
FROM {self.table_name}
|
||||||
|
{where_clause}
|
||||||
|
ORDER BY distance
|
||||||
|
LIMIT %s
|
||||||
|
"""
|
||||||
|
query_params = [query_vector_str] + params + [top_k]
|
||||||
|
|
||||||
|
cur.execute(sql, query_params)
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
|
|
||||||
|
for record in cur:
|
||||||
|
try:
|
||||||
|
distance = float(record["distance"])
|
||||||
|
# Convert distance to similarity score
|
||||||
|
if self.distance_function == "cosine":
|
||||||
|
# For cosine distance: similarity = 1 - distance
|
||||||
|
similarity = 1.0 - distance
|
||||||
|
else:
|
||||||
|
# For euclidean distance: use inverse relationship
|
||||||
|
# similarity = 1 / (1 + distance)
|
||||||
|
similarity = 1.0 / (1.0 + distance)
|
||||||
|
|
||||||
|
metadata = record["meta"]
|
||||||
|
if isinstance(metadata, str):
|
||||||
|
metadata = json.loads(metadata)
|
||||||
|
metadata["score"] = similarity
|
||||||
|
metadata["distance"] = distance
|
||||||
|
|
||||||
|
if similarity >= score_threshold:
|
||||||
|
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||||
|
except (ValueError, json.JSONDecodeError) as e:
|
||||||
|
logger.warning("Error processing search result: %s", e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
|
top_k = kwargs.get("top_k", 5)
|
||||||
|
if not isinstance(top_k, int) or top_k <= 0:
|
||||||
|
raise ValueError("top_k must be a positive integer")
|
||||||
|
|
||||||
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
|
where_clause = ""
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if document_ids_filter:
|
||||||
|
placeholders = ",".join(["%s"] * len(document_ids_filter))
|
||||||
|
where_clause = f" AND JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
|
||||||
|
params.extend(document_ids_filter)
|
||||||
|
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
# Build query parameters: query (twice for MATCH clauses), document_ids_filter (if any), top_k
|
||||||
|
query_params = [query, query] + params + [top_k]
|
||||||
|
cur.execute(
|
||||||
|
f"""SELECT meta, text,
|
||||||
|
MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) AS score
|
||||||
|
FROM {self.table_name}
|
||||||
|
WHERE MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE)
|
||||||
|
{where_clause}
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT %s""",
|
||||||
|
query_params,
|
||||||
|
)
|
||||||
|
docs = []
|
||||||
|
for record in cur:
|
||||||
|
metadata = record["meta"]
|
||||||
|
if isinstance(metadata, str):
|
||||||
|
metadata = json.loads(metadata)
|
||||||
|
metadata["score"] = float(record["score"])
|
||||||
|
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def delete(self):
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||||
|
|
||||||
|
def _create_collection(self, dimension: int):
|
||||||
|
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||||
|
lock_name = f"{collection_exist_cache_key}_lock"
|
||||||
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
|
if redis_client.get(collection_exist_cache_key):
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
# Create table with vector column and vector index
|
||||||
|
cur.execute(
|
||||||
|
SQL_CREATE_TABLE.format(
|
||||||
|
table_name=self.table_name,
|
||||||
|
dimension=dimension,
|
||||||
|
distance_function=self.distance_function,
|
||||||
|
hnsw_m=self.hnsw_m,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Create metadata index (check if exists first)
|
||||||
|
try:
|
||||||
|
cur.execute(SQL_CREATE_META_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
|
||||||
|
except MySQLError as e:
|
||||||
|
if e.errno != 1061: # Duplicate key name
|
||||||
|
logger.warning("Could not create meta index: %s", e)
|
||||||
|
|
||||||
|
# Create full-text index for text search
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
SQL_CREATE_FULLTEXT_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
if e.errno != 1061: # Duplicate key name
|
||||||
|
logger.warning("Could not create fulltext index: %s", e)
|
||||||
|
|
||||||
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
|
|
||||||
|
class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
|
||||||
|
def _validate_distance_function(self, distance_function: str) -> Literal["cosine", "euclidean"]:
|
||||||
|
"""Validate and return the distance function as a proper Literal type."""
|
||||||
|
if distance_function not in ["cosine", "euclidean"]:
|
||||||
|
raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
|
||||||
|
return cast(Literal["cosine", "euclidean"], distance_function)
|
||||||
|
|
||||||
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
|
||||||
|
if dataset.index_struct_dict:
|
||||||
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
|
collection_name = class_prefix
|
||||||
|
else:
|
||||||
|
dataset_id = dataset.id
|
||||||
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
dataset.index_struct = json.dumps(
|
||||||
|
self.gen_index_struct_dict(VectorType.ALIBABACLOUD_MYSQL, collection_name)
|
||||||
|
)
|
||||||
|
return AlibabaCloudMySQLVector(
|
||||||
|
collection_name=collection_name,
|
||||||
|
config=AlibabaCloudMySQLVectorConfig(
|
||||||
|
host=dify_config.ALIBABACLOUD_MYSQL_HOST or "localhost",
|
||||||
|
port=dify_config.ALIBABACLOUD_MYSQL_PORT,
|
||||||
|
user=dify_config.ALIBABACLOUD_MYSQL_USER or "root",
|
||||||
|
password=dify_config.ALIBABACLOUD_MYSQL_PASSWORD or "",
|
||||||
|
database=dify_config.ALIBABACLOUD_MYSQL_DATABASE or "dify",
|
||||||
|
max_connection=dify_config.ALIBABACLOUD_MYSQL_MAX_CONNECTION,
|
||||||
|
charset=dify_config.ALIBABACLOUD_MYSQL_CHARSET or "utf8mb4",
|
||||||
|
distance_function=self._validate_distance_function(
|
||||||
|
dify_config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION or "cosine"
|
||||||
|
),
|
||||||
|
hnsw_m=dify_config.ALIBABACLOUD_MYSQL_HNSW_M or 6,
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -71,6 +71,12 @@ class Vector:
|
|||||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
|
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
|
||||||
|
|
||||||
return MilvusVectorFactory
|
return MilvusVectorFactory
|
||||||
|
case VectorType.ALIBABACLOUD_MYSQL:
|
||||||
|
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||||
|
AlibabaCloudMySQLVectorFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AlibabaCloudMySQLVectorFactory
|
||||||
case VectorType.MYSCALE:
|
case VectorType.MYSCALE:
|
||||||
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
|
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from enum import StrEnum
|
|||||||
|
|
||||||
|
|
||||||
class VectorType(StrEnum):
|
class VectorType(StrEnum):
|
||||||
|
ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
|
||||||
ANALYTICDB = "analyticdb"
|
ANALYTICDB = "analyticdb"
|
||||||
CHROMA = "chroma"
|
CHROMA = "chroma"
|
||||||
MILVUS = "milvus"
|
MILVUS = "milvus"
|
||||||
|
|||||||
@@ -217,4 +217,5 @@ vdb = [
|
|||||||
"weaviate-client~=3.24.0",
|
"weaviate-client~=3.24.0",
|
||||||
"xinference-client~=1.2.2",
|
"xinference-client~=1.2.2",
|
||||||
"mo-vector~=0.1.13",
|
"mo-vector~=0.1.13",
|
||||||
|
"mysql-connector-python>=9.3.0",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,722 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||||
|
AlibabaCloudMySQLVector,
|
||||||
|
AlibabaCloudMySQLVectorConfig,
|
||||||
|
)
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mysql.connector import Error as MySQLError
|
||||||
|
except ImportError:
|
||||||
|
# Fallback for testing environments where mysql-connector-python might not be installed
|
||||||
|
class MySQLError(Exception):
|
||||||
|
def __init__(self, errno, msg):
|
||||||
|
self.errno = errno
|
||||||
|
self.msg = msg
|
||||||
|
super().__init__(msg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.config = AlibabaCloudMySQLVectorConfig(
|
||||||
|
host="localhost",
|
||||||
|
port=3306,
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
database="test_db",
|
||||||
|
max_connection=5,
|
||||||
|
charset="utf8mb4",
|
||||||
|
)
|
||||||
|
self.collection_name = "test_collection"
|
||||||
|
|
||||||
|
# Sample documents for testing
|
||||||
|
self.sample_documents = [
|
||||||
|
Document(
|
||||||
|
page_content="This is a test document about AI.",
|
||||||
|
metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
page_content="Another document about machine learning.",
|
||||||
|
metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Sample embeddings
|
||||||
|
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_init(self, mock_pool_class):
|
||||||
|
"""Test AlibabaCloudMySQLVector initialization."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor for vector support check
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [
|
||||||
|
{"VERSION()": "8.0.36"}, # Version check
|
||||||
|
{"vector_support": True}, # Vector support check
|
||||||
|
]
|
||||||
|
|
||||||
|
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
|
||||||
|
assert alibabacloud_mysql_vector.collection_name == self.collection_name
|
||||||
|
assert alibabacloud_mysql_vector.table_name == self.collection_name.lower()
|
||||||
|
assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql"
|
||||||
|
assert alibabacloud_mysql_vector.distance_function == "cosine"
|
||||||
|
assert alibabacloud_mysql_vector.pool is not None
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||||
|
def test_create_collection(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test collection creation."""
|
||||||
|
# Mock Redis operations
|
||||||
|
mock_redis.lock.return_value.__enter__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value.__exit__ = MagicMock()
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.set.return_value = None
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [
|
||||||
|
{"VERSION()": "8.0.36"}, # Version check
|
||||||
|
{"vector_support": True}, # Vector support check
|
||||||
|
]
|
||||||
|
|
||||||
|
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
alibabacloud_mysql_vector._create_collection(768)
|
||||||
|
|
||||||
|
# Verify SQL execution calls - should include table creation and index creation
|
||||||
|
assert mock_cursor.execute.called
|
||||||
|
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
|
||||||
|
mock_redis.set.assert_called_once()
|
||||||
|
|
||||||
|
def test_config_validation(self):
|
||||||
|
"""Test configuration validation."""
|
||||||
|
# Test missing required fields
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AlibabaCloudMySQLVectorConfig(
|
||||||
|
host="", # Empty host should raise error
|
||||||
|
port=3306,
|
||||||
|
user="test",
|
||||||
|
password="test",
|
||||||
|
database="test",
|
||||||
|
max_connection=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_vector_support_check_success(self, mock_pool_class):
|
||||||
|
"""Test successful vector support check."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
# Should not raise an exception
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
assert vector_store is not None
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_vector_support_check_failure(self, mock_pool_class):
|
||||||
|
"""Test vector support check failure."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as context:
|
||||||
|
AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
|
||||||
|
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_vector_support_check_function_error(self, mock_pool_class):
|
||||||
|
"""Test vector support check with function not found error."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"}
|
||||||
|
mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as context:
|
||||||
|
AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
|
||||||
|
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||||
|
def test_create_documents(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test creating documents with embeddings."""
|
||||||
|
# Setup mocks
|
||||||
|
self._setup_mocks(mock_redis, mock_pool_class)
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
result = vector_store.create(self.sample_documents, self.sample_embeddings)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "doc1" in result
|
||||||
|
assert "doc2" in result
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_add_texts(self, mock_pool_class):
|
||||||
|
"""Test adding texts to the vector store."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
result = vector_store.add_texts(self.sample_documents, self.sample_embeddings)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
mock_cursor.executemany.assert_called_once()
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_text_exists(self, mock_pool_class):
|
||||||
|
"""Test checking if text exists."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [
|
||||||
|
{"VERSION()": "8.0.36"},
|
||||||
|
{"vector_support": True},
|
||||||
|
{"id": "doc1"}, # Text exists
|
||||||
|
]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
exists = vector_store.text_exists("doc1")
|
||||||
|
|
||||||
|
assert exists
|
||||||
|
# Check that the correct SQL was executed (last call after init)
|
||||||
|
execute_calls = mock_cursor.execute.call_args_list
|
||||||
|
last_call = execute_calls[-1]
|
||||||
|
assert "SELECT id FROM" in last_call[0][0]
|
||||||
|
assert last_call[0][1] == ("doc1",)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_text_not_exists(self, mock_pool_class):
|
||||||
|
"""Test checking if text does not exist."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [
|
||||||
|
{"VERSION()": "8.0.36"},
|
||||||
|
{"vector_support": True},
|
||||||
|
None, # Text does not exist
|
||||||
|
]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
exists = vector_store.text_exists("nonexistent")
|
||||||
|
|
||||||
|
assert not exists
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_get_by_ids(self, mock_pool_class):
|
||||||
|
"""Test getting documents by IDs."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
mock_cursor.__iter__ = lambda self: iter(
|
||||||
|
[
|
||||||
|
{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"},
|
||||||
|
{"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
docs = vector_store.get_by_ids(["doc1", "doc2"])
|
||||||
|
|
||||||
|
assert len(docs) == 2
|
||||||
|
assert docs[0].page_content == "Test document 1"
|
||||||
|
assert docs[1].page_content == "Test document 2"
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_get_by_ids_empty_list(self, mock_pool_class):
|
||||||
|
"""Test getting documents with empty ID list."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
docs = vector_store.get_by_ids([])
|
||||||
|
|
||||||
|
assert len(docs) == 0
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_delete_by_ids(self, mock_pool_class):
|
||||||
|
"""Test deleting documents by IDs."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
vector_store.delete_by_ids(["doc1", "doc2"])
|
||||||
|
|
||||||
|
# Check that delete SQL was executed
|
||||||
|
execute_calls = mock_cursor.execute.call_args_list
|
||||||
|
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||||
|
assert len(delete_calls) == 1
|
||||||
|
delete_call = delete_calls[0]
|
||||||
|
assert "DELETE FROM" in delete_call[0][0]
|
||||||
|
assert delete_call[0][1] == ["doc1", "doc2"]
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_delete_by_ids_empty_list(self, mock_pool_class):
|
||||||
|
"""Test deleting with empty ID list."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
vector_store.delete_by_ids([]) # Should not raise an exception
|
||||||
|
|
||||||
|
# Verify no delete SQL was executed
|
||||||
|
execute_calls = mock_cursor.execute.call_args_list
|
||||||
|
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||||
|
assert len(delete_calls) == 0
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
|
||||||
|
"""Test deleting when table doesn't exist."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
# Simulate table doesn't exist error on delete
|
||||||
|
|
||||||
|
def execute_side_effect(*args, **kwargs):
|
||||||
|
if "DELETE" in args[0]:
|
||||||
|
raise MySQLError(errno=1146, msg="Table doesn't exist")
|
||||||
|
|
||||||
|
mock_cursor.execute.side_effect = execute_side_effect
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
# Should not raise an exception
|
||||||
|
vector_store.delete_by_ids(["doc1"])
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_delete_by_metadata_field(self, mock_pool_class):
|
||||||
|
"""Test deleting documents by metadata field."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
vector_store.delete_by_metadata_field("document_id", "dataset1")
|
||||||
|
|
||||||
|
# Check that the correct SQL was executed
|
||||||
|
execute_calls = mock_cursor.execute.call_args_list
|
||||||
|
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||||
|
assert len(delete_calls) == 1
|
||||||
|
delete_call = delete_calls[0]
|
||||||
|
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
|
||||||
|
assert delete_call[0][1] == ("$.document_id", "dataset1")
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_vector_cosine(self, mock_pool_class):
|
||||||
|
"""Test vector search with cosine distance."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
mock_cursor.__iter__ = lambda self: iter(
|
||||||
|
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}]
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
docs = vector_store.search_by_vector(query_vector, top_k=5)
|
||||||
|
|
||||||
|
assert len(docs) == 1
|
||||||
|
assert docs[0].page_content == "Test document 1"
|
||||||
|
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
|
||||||
|
assert docs[0].metadata["distance"] == 0.1
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_vector_euclidean(self, mock_pool_class):
|
||||||
|
"""Test vector search with euclidean distance."""
|
||||||
|
config = AlibabaCloudMySQLVectorConfig(
|
||||||
|
host="localhost",
|
||||||
|
port=3306,
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
database="test_db",
|
||||||
|
max_connection=5,
|
||||||
|
distance_function="euclidean",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
mock_cursor.__iter__ = lambda self: iter(
|
||||||
|
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}]
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, config)
|
||||||
|
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
docs = vector_store.search_by_vector(query_vector, top_k=5)
|
||||||
|
|
||||||
|
assert len(docs) == 1
|
||||||
|
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_vector_with_filter(self, mock_pool_class):
|
||||||
|
"""Test vector search with document ID filter."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
mock_cursor.__iter__ = lambda self: iter([])
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"])
|
||||||
|
|
||||||
|
# Verify the SQL contains the WHERE clause for filtering
|
||||||
|
execute_calls = mock_cursor.execute.call_args_list
|
||||||
|
search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)]
|
||||||
|
assert len(search_calls) > 0
|
||||||
|
search_call = search_calls[0]
|
||||||
|
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
|
||||||
|
"""Test vector search with score threshold."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
mock_cursor.__iter__ = lambda self: iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"meta": json.dumps({"doc_id": "doc1", "source": "test"}),
|
||||||
|
"text": "High similarity document",
|
||||||
|
"distance": 0.1, # High similarity (score = 0.9)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"meta": json.dumps({"doc_id": "doc2", "source": "test"}),
|
||||||
|
"text": "Low similarity document",
|
||||||
|
"distance": 0.8, # Low similarity (score = 0.2)
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5)
|
||||||
|
|
||||||
|
# Only the high similarity document should be returned
|
||||||
|
assert len(docs) == 1
|
||||||
|
assert docs[0].page_content == "High similarity document"
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
|
||||||
|
"""Test vector search with invalid top_k."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
vector_store.search_by_vector(query_vector, top_k=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
vector_store.search_by_vector(query_vector, top_k="invalid")
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_full_text(self, mock_pool_class):
|
||||||
|
"""Test full-text search."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
mock_cursor.__iter__ = lambda self: iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"meta": {"doc_id": "doc1", "source": "test"},
|
||||||
|
"text": "This document contains machine learning content",
|
||||||
|
"score": 1.5,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
docs = vector_store.search_by_full_text("machine learning", top_k=5)
|
||||||
|
|
||||||
|
assert len(docs) == 1
|
||||||
|
assert docs[0].page_content == "This document contains machine learning content"
|
||||||
|
assert docs[0].metadata["score"] == 1.5
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_full_text_with_filter(self, mock_pool_class):
|
||||||
|
"""Test full-text search with document ID filter."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
mock_cursor.__iter__ = lambda self: iter([])
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"])
|
||||||
|
|
||||||
|
# Verify the SQL contains the AND clause for filtering
|
||||||
|
execute_calls = mock_cursor.execute.call_args_list
|
||||||
|
search_calls = [call for call in execute_calls if "MATCH" in str(call)]
|
||||||
|
assert len(search_calls) > 0
|
||||||
|
search_call = search_calls[0]
|
||||||
|
assert "AND JSON_UNQUOTE" in search_call[0][0]
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
|
||||||
|
"""Test full-text search with invalid top_k."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
vector_store.search_by_full_text("test", top_k=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
vector_store.search_by_full_text("test", top_k="invalid")
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_delete_collection(self, mock_pool_class):
|
||||||
|
"""Test deleting the entire collection."""
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||||
|
vector_store.delete()
|
||||||
|
|
||||||
|
# Check that DROP TABLE SQL was executed
|
||||||
|
execute_calls = mock_cursor.execute.call_args_list
|
||||||
|
drop_calls = [call for call in execute_calls if "DROP TABLE" in str(call)]
|
||||||
|
assert len(drop_calls) == 1
|
||||||
|
drop_call = drop_calls[0]
|
||||||
|
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
|
)
|
||||||
|
def test_unsupported_distance_function(self, mock_pool_class):
|
||||||
|
"""Test that Pydantic validation rejects unsupported distance functions."""
|
||||||
|
# Test that creating config with unsupported distance function raises ValidationError
|
||||||
|
with pytest.raises(ValueError) as context:
|
||||||
|
AlibabaCloudMySQLVectorConfig(
|
||||||
|
host="localhost",
|
||||||
|
port=3306,
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
database="test_db",
|
||||||
|
max_connection=5,
|
||||||
|
distance_function="manhattan", # Unsupported - not in Literal["cosine", "euclidean"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# The error should be related to validation
|
||||||
|
assert "Input should be 'cosine' or 'euclidean'" in str(context.value) or "manhattan" in str(context.value)
|
||||||
|
|
||||||
|
def _setup_mocks(self, mock_redis, mock_pool_class):
|
||||||
|
"""Helper method to setup common mocks."""
|
||||||
|
# Mock Redis operations
|
||||||
|
mock_redis.lock.return_value.__enter__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value.__exit__ = MagicMock()
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.set.return_value = None
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.get_connection.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
21
api/uv.lock
generated
21
api/uv.lock
generated
@@ -1449,6 +1449,7 @@ vdb = [
|
|||||||
{ name = "couchbase" },
|
{ name = "couchbase" },
|
||||||
{ name = "elasticsearch" },
|
{ name = "elasticsearch" },
|
||||||
{ name = "mo-vector" },
|
{ name = "mo-vector" },
|
||||||
|
{ name = "mysql-connector-python" },
|
||||||
{ name = "opensearch-py" },
|
{ name = "opensearch-py" },
|
||||||
{ name = "oracledb" },
|
{ name = "oracledb" },
|
||||||
{ name = "pgvecto-rs", extra = ["sqlalchemy"] },
|
{ name = "pgvecto-rs", extra = ["sqlalchemy"] },
|
||||||
@@ -1637,6 +1638,7 @@ vdb = [
|
|||||||
{ name = "couchbase", specifier = "~=4.3.0" },
|
{ name = "couchbase", specifier = "~=4.3.0" },
|
||||||
{ name = "elasticsearch", specifier = "==8.14.0" },
|
{ name = "elasticsearch", specifier = "==8.14.0" },
|
||||||
{ name = "mo-vector", specifier = "~=0.1.13" },
|
{ name = "mo-vector", specifier = "~=0.1.13" },
|
||||||
|
{ name = "mysql-connector-python", specifier = ">=9.3.0" },
|
||||||
{ name = "opensearch-py", specifier = "==2.4.0" },
|
{ name = "opensearch-py", specifier = "==2.4.0" },
|
||||||
{ name = "oracledb", specifier = "==3.3.0" },
|
{ name = "oracledb", specifier = "==3.3.0" },
|
||||||
{ name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" },
|
{ name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" },
|
||||||
@@ -3437,6 +3439,25 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
|
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mysql-connector-python"
|
||||||
|
version = "9.4.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/02/77/2b45e6460d05b1f1b7a4c8eb79a50440b4417971973bb78c9ef6cad630a6/mysql_connector_python-9.4.0.tar.gz", hash = "sha256:d111360332ae78933daf3d48ff497b70739aa292ab0017791a33e826234e743b", size = 12185532, upload-time = "2025-07-22T08:02:05.788Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fe/0c/4365a802129be9fa63885533c38be019f1c6b6f5bcf8844ac53902314028/mysql_connector_python-9.4.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:7df1a8ddd182dd8adc914f6dc902a986787bf9599705c29aca7b2ce84e79d361", size = 17501627, upload-time = "2025-07-22T07:57:45.416Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c0/bf/ca596c00d7a6eaaf8ef2f66c9b23cd312527f483073c43ffac7843049cb4/mysql_connector_python-9.4.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:3892f20472e13e63b1fb4983f454771dd29f211b09724e69a9750e299542f2f8", size = 18369494, upload-time = "2025-07-22T07:57:49.714Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/25/14/6510a11ed9f80d77f743dc207773092c4ab78d5efa454b39b48480315d85/mysql_connector_python-9.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:d3e87142103d71c4df647ece30f98e85e826652272ed1c74822b56f6acdc38e7", size = 33516187, upload-time = "2025-07-22T07:57:55.294Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/16/a8/4f99d80f1cf77733ce9a44b6adb7f0dd7079e7afa51ca4826515ef0c3e16/mysql_connector_python-9.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:b27fcd403436fe83bafb2fe7fcb785891e821e639275c4ad3b3bd1e25f533206", size = 33917818, upload-time = "2025-07-22T07:58:00.523Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/15/9c/127f974ca9d5ee25373cb5433da06bb1f36e05f2a6b7436da1fe9c6346b0/mysql_connector_python-9.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd6ff5afb9c324b0bbeae958c93156cce4168c743bf130faf224d52818d1f0ee", size = 16392378, upload-time = "2025-07-22T07:58:04.669Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/03/7c/a543fb17c2dfa6be8548dfdc5879a0c7924cd5d1c79056c48472bb8fe858/mysql_connector_python-9.4.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:4efa3898a24aba6a4bfdbf7c1f5023c78acca3150d72cc91199cca2ccd22f76f", size = 17503693, upload-time = "2025-07-22T07:58:08.96Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/6e/c22fbee05f5cfd6ba76155b6d45f6261d8d4c1e36e23de04e7f25fbd01a4/mysql_connector_python-9.4.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:665c13e7402235162e5b7a2bfdee5895192121b64ea455c90a81edac6a48ede5", size = 18371987, upload-time = "2025-07-22T07:58:13.273Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b4/fd/f426f5f35a3d3180c7f84d1f96b4631be2574df94ca1156adab8618b236c/mysql_connector_python-9.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:815aa6cad0f351c1223ef345781a538f2e5e44ef405fdb3851eb322bd9c4ca2b", size = 33516214, upload-time = "2025-07-22T07:58:18.967Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/45/5a/1b053ae80b43cd3ccebc4bb99a98826969b3b0f8adebdcc2530750ad76ed/mysql_connector_python-9.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b3436a2c8c0ec7052932213e8d01882e6eb069dbab33402e685409084b133a1c", size = 33918565, upload-time = "2025-07-22T07:58:25.28Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/69/36b989de675d98ba8ff7d45c96c30c699865c657046f2e32db14e78f13d9/mysql_connector_python-9.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:57b0c224676946b70548c56798d5023f65afa1ba5b8ac9f04a143d27976c7029", size = 16392563, upload-time = "2025-07-22T07:58:29.623Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/36/34/b6165e15fd45a8deb00932d8e7d823de7650270873b4044c4db6688e1d8f/mysql_connector_python-9.4.0-py2.py3-none-any.whl", hash = "sha256:56e679169c704dab279b176fab2a9ee32d2c632a866c0f7cd48a8a1e2cf802c4", size = 406574, upload-time = "2025-07-22T07:59:08.394Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nest-asyncio"
|
name = "nest-asyncio"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
|
|||||||
@@ -449,7 +449,7 @@ SUPABASE_URL=your-server-url
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
|
|
||||||
# The type of vector store to use.
|
# The type of vector store to use.
|
||||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`.
|
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`.
|
||||||
VECTOR_STORE=weaviate
|
VECTOR_STORE=weaviate
|
||||||
# Prefix used to create collection name in vector database
|
# Prefix used to create collection name in vector database
|
||||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||||
@@ -580,6 +580,15 @@ ORACLE_WALLET_LOCATION=/app/api/storage/wallet
|
|||||||
ORACLE_WALLET_PASSWORD=dify
|
ORACLE_WALLET_PASSWORD=dify
|
||||||
ORACLE_IS_AUTONOMOUS=false
|
ORACLE_IS_AUTONOMOUS=false
|
||||||
|
|
||||||
|
# AlibabaCloud MySQL configuration, only available when VECTOR_STORE is `alibabcloud_mysql`
|
||||||
|
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||||
|
ALIBABACLOUD_MYSQL_PORT=3306
|
||||||
|
ALIBABACLOUD_MYSQL_USER=root
|
||||||
|
ALIBABACLOUD_MYSQL_PASSWORD=difyai123456
|
||||||
|
ALIBABACLOUD_MYSQL_DATABASE=dify
|
||||||
|
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||||
|
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||||
|
|
||||||
# relyt configurations, only available when VECTOR_STORE is `relyt`
|
# relyt configurations, only available when VECTOR_STORE is `relyt`
|
||||||
RELYT_HOST=db
|
RELYT_HOST=db
|
||||||
RELYT_PORT=5432
|
RELYT_PORT=5432
|
||||||
|
|||||||
@@ -244,6 +244,13 @@ x-shared-env: &shared-api-worker-env
|
|||||||
ORACLE_WALLET_LOCATION: ${ORACLE_WALLET_LOCATION:-/app/api/storage/wallet}
|
ORACLE_WALLET_LOCATION: ${ORACLE_WALLET_LOCATION:-/app/api/storage/wallet}
|
||||||
ORACLE_WALLET_PASSWORD: ${ORACLE_WALLET_PASSWORD:-dify}
|
ORACLE_WALLET_PASSWORD: ${ORACLE_WALLET_PASSWORD:-dify}
|
||||||
ORACLE_IS_AUTONOMOUS: ${ORACLE_IS_AUTONOMOUS:-false}
|
ORACLE_IS_AUTONOMOUS: ${ORACLE_IS_AUTONOMOUS:-false}
|
||||||
|
ALIBABACLOUD_MYSQL_HOST: ${ALIBABACLOUD_MYSQL_HOST:-127.0.0.1}
|
||||||
|
ALIBABACLOUD_MYSQL_PORT: ${ALIBABACLOUD_MYSQL_PORT:-3306}
|
||||||
|
ALIBABACLOUD_MYSQL_USER: ${ALIBABACLOUD_MYSQL_USER:-root}
|
||||||
|
ALIBABACLOUD_MYSQL_PASSWORD: ${ALIBABACLOUD_MYSQL_PASSWORD:-difyai123456}
|
||||||
|
ALIBABACLOUD_MYSQL_DATABASE: ${ALIBABACLOUD_MYSQL_DATABASE:-dify}
|
||||||
|
ALIBABACLOUD_MYSQL_MAX_CONNECTION: ${ALIBABACLOUD_MYSQL_MAX_CONNECTION:-5}
|
||||||
|
ALIBABACLOUD_MYSQL_HNSW_M: ${ALIBABACLOUD_MYSQL_HNSW_M:-6}
|
||||||
RELYT_HOST: ${RELYT_HOST:-db}
|
RELYT_HOST: ${RELYT_HOST:-db}
|
||||||
RELYT_PORT: ${RELYT_PORT:-5432}
|
RELYT_PORT: ${RELYT_PORT:-5432}
|
||||||
RELYT_USER: ${RELYT_USER:-postgres}
|
RELYT_USER: ${RELYT_USER:-postgres}
|
||||||
|
|||||||
Reference in New Issue
Block a user