feat: support baidu vector db (#9185)

Shili Cao 2024-10-12 23:24:17 +08:00 committed by GitHub
parent 793205afc5
commit 2ec6ffe478
15 changed files with 582 additions and 13 deletions


@@ -208,6 +208,15 @@ OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true
# Baidu configuration
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5


@@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.BAIDU:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.BAIDU,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")


@@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class BaiduVectorDBConfig(BaseSettings):
"""
Configuration settings for Baidu Vector Database
"""
BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
default=None,
)
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
default=30000,
)
BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
description="Account for authenticating with the Baidu Vector Database",
default=None,
)
BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Baidu Vector Database service",
default=None,
)
BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Name of the specific Baidu Vector Database to connect to",
default=None,
)
BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
description="Number of shards for the Baidu Vector Database (default is 1)",
default=1,
)
BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Number of replicas for the Baidu Vector Database (default is 3)",
default=3,
)
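A minimal usage sketch (not part of the diff; the env var value is illustrative): pydantic-settings resolves each field of BaiduVectorDBConfig from the identically named environment variable and falls back to the declared default otherwise.

import os

os.environ["BAIDU_VECTOR_DB_ENDPOINT"] = "http://127.0.0.1:5287"  # hypothetical value
config = BaiduVectorDBConfig()
assert config.BAIDU_VECTOR_DB_SHARD == 1  # unset, so the declared default applies
print(config.BAIDU_VECTOR_DB_ENDPOINT)  # http://127.0.0.1:5287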


@@ -617,6 +617,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@@ -653,6 +654,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (


@@ -0,0 +1,272 @@
import json
import time
import uuid
from typing import Any
from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class BaiduConfig(BaseModel):
endpoint: str
connection_timeout_in_mills: int = 30 * 1000
account: str
api_key: str
database: str
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = 1
replicas: int = 3
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["endpoint"]:
raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
if not values["account"]:
raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
if not values["api_key"]:
raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
if not values["database"]:
raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
return values
class BaiduVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"
field_text: str = "text"
field_metadata: str = "metadata"
field_app_id: str = "app_id"
field_annotation_id: str = "annotation_id"
index_vector: str = "vector_idx"
def __init__(self, collection_name: str, config: BaiduConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._db = self._init_database()
def get_type(self) -> str:
return VectorType.BAIDU
def to_index_struct(self) -> dict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0]))
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
total_count = len(documents)
batch_size = 1000
        # Upsert texts and embeddings in batches
table = self._db.table(self._collection_name)
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
            for i in range(start, end):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadatas[i]),
app_id=metadatas[i].get("app_id", ""),
annotation_id=metadatas[i].get("annotation_id", ""),
)
rows.append(row)
table.upsert(rows=rows)
        # Rebuild the vector index after the upsert finishes, then wait for it to become ready
table.rebuild_index(self.index_vector)
while True:
time.sleep(1)
index = table.describe_index(self.index_vector)
if index.state == IndexState.NORMAL:
break
def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
if res and res.code == 0:
return True
return False
def delete_by_ids(self, ids: list[str]) -> None:
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
res = self._db.table(self._collection_name).search(
anns=anns,
projections=[self.field_id, self.field_text, self.field_metadata],
retrieve_vector=True,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # Baidu Vector Database does not support BM25 full-text search in the current version
return []
def _get_search_res(self, res, score_threshold):
docs = []
for row in res.rows:
row_data = row.get("row", {})
            meta = row_data.get(self.field_metadata)
            meta = json.loads(meta) if meta is not None else {}
score = row.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)
return docs
def delete(self) -> None:
self._db.drop_table(table_name=self._collection_name)
def _init_client(self, config) -> MochowClient:
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
client = MochowClient(config)
return client
def _init_database(self):
exists = False
for db in self._client.list_databases():
if db.database_name == self._client_config.database:
exists = True
break
        # Create the database if it does not exist yet
if exists:
return self._client.database(self._client_config.database)
else:
return self._client.create_database(database_name=self._client_config.database)
def _table_existed(self) -> bool:
tables = self._db.list_table()
return any(table.table_name == self._collection_name for table in tables)
def _create_table(self, dimension: int) -> None:
        # Acquire a distributed lock, then create the table if needed
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(table_exist_cache_key):
return
if self._table_existed():
return
            # Resolve the configured index_type and metric_type against the SDK enums
index_type = None
for k, v in IndexType.__members__.items():
if k == self._client_config.index_type:
index_type = v
if index_type is None:
raise ValueError("unsupported index_type")
metric_type = None
for k, v in MetricType.__members__.items():
if k == self._client_config.metric_type:
metric_type = v
if metric_type is None:
raise ValueError("unsupported metric_type")
# Construct field schema
fields = []
fields.append(
Field(
self.field_id,
FieldType.STRING,
primary_key=True,
partition_key=True,
auto_increment=False,
not_null=True,
)
)
fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
fields.append(Field(self.field_app_id, FieldType.STRING))
fields.append(Field(self.field_annotation_id, FieldType.STRING))
fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
# Construct vector index params
indexes = []
indexes.append(
VectorIndex(
index_name="vector_idx",
index_type=index_type,
field="vector",
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
)
)
# Create table
self._db.create_table(
table_name=self._collection_name,
replication=self._client_config.replicas,
partition=Partition(partition_num=self._client_config.shard),
schema=Schema(fields=fields, indexes=indexes),
description="Table for Dify",
)
redis_client.set(table_exist_cache_key, 1, ex=3600)
            # Wait until the table reaches NORMAL state
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
class BaiduVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))
return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
),
)
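As a usage sketch (values illustrative, mirroring the integration test further below), the store can also be constructed directly instead of through BaiduVectorFactory; documents, embeddings, and query_vector are assumed inputs:

vector = BaiduVector(
    collection_name="dify",
    config=BaiduConfig(
        endpoint="http://127.0.0.1:5287",
        account="root",
        api_key="dify",
        database="dify",
    ),
)
# create() sizes the table to the first embedding's dimension,
# builds the HNSW index, and upserts the documents.
vector.create(texts=documents, embeddings=embeddings)
docs = vector.search_by_vector(query_vector, top_k=4, score_threshold=0.5)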


@@ -103,6 +103,10 @@ class Vector:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
return BaiduVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")


@@ -16,3 +16,4 @@ class VectorType(str, Enum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
BAIDU = "baidu"

api/poetry.lock (generated)

@@ -732,7 +732,7 @@ name = "bce-python-sdk"
version = "0.9.23"
description = "BCE SDK for python"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4"
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7"
files = [
{file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"},
{file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"},
@@ -847,7 +847,7 @@ name = "botocore"
version = "1.35.38"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">= 3.8"
python-versions = ">=3.8"
files = [
{file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"},
{file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"},
@@ -1068,7 +1068,7 @@ name = "build"
version = "1.2.2.post1"
description = "A simple, correct Python build frontend"
optional = false
python-versions = ">= 3.8"
python-versions = ">=3.8"
files = [
{file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"},
{file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"},
@@ -3385,7 +3385,7 @@ name = "gotrue"
version = "2.9.2"
description = "Python Client Library for Supabase Auth"
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"},
{file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"},
@@ -4415,7 +4415,7 @@ name = "langfuse"
version = "2.51.5"
description = "A client library for accessing langfuse"
optional = false
python-versions = ">=3.8.1,<4.0"
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"},
{file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"},
@@ -4440,7 +4440,7 @@ name = "langsmith"
version = "0.1.134"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"},
{file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"},
@@ -6429,7 +6429,7 @@ name = "postgrest"
version = "0.17.1"
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"},
{file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"},
@@ -7047,6 +7047,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r
dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"]
model = ["milvus-model (>=0.1.0)"]
[[package]]
name = "pymochow"
version = "1.3.1"
description = "Python SDK for mochow"
optional = false
python-versions = ">=3.7"
files = [
{file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"},
{file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"},
]
[package.dependencies]
future = "*"
orjson = "*"
requests = "*"
[[package]]
name = "pymysql"
version = "1.1.1"
@@ -7746,7 +7762,7 @@ name = "realtime"
version = "2.0.2"
description = ""
optional = false
python-versions = ">=3.9,<4.0"
python-versions = "<4.0,>=3.9"
files = [
{file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"},
{file = "realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"},
@@ -8173,7 +8189,7 @@ name = "s3transfer"
version = "0.10.3"
description = "An Amazon S3 Transfer Manager"
optional = false
python-versions = ">= 3.8"
python-versions = ">=3.8"
files = [
{file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
{file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
@@ -8417,6 +8433,11 @@ files = [
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
{file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
{file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
@@ -8836,7 +8857,7 @@ name = "storage3"
version = "0.8.1"
description = "Supabase Storage client for Python."
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"},
{file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"},
@@ -8882,7 +8903,7 @@ name = "supabase"
version = "2.8.1"
description = "Supabase client for Python."
optional = false
python-versions = ">=3.9,<4.0"
python-versions = "<4.0,>=3.9"
files = [
{file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"},
{file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"},
@@ -8902,7 +8923,7 @@ name = "supafunc"
version = "0.6.1"
description = "Library for Supabase Functions"
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"},
{file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"},
@@ -10615,4 +10636,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "cc10ee218369eb5576d1e5ac8aeeb72e8927bbcb8bd1ac1594167c45aa9d9a21"
content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774"


@@ -242,6 +242,7 @@ oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5"
pymilvus = "~2.4.4"
pymochow = "1.3.1"
qdrant-client = "1.7.3"
tcvectordb = "1.3.2"
tidb-vector = "0.0.9"


@@ -0,0 +1,154 @@
import os
import pytest
from _pytest.monkeypatch import MonkeyPatch
from pymochow import MochowClient
from pymochow.model.database import Database
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
from pymochow.model.schema import HNSWParams, VectorIndex
from pymochow.model.table import Table
from requests.adapters import HTTPAdapter
class MockBaiduVectorDBClass:
    def mock_vector_db_client(
        self,
        config=None,
        adapter: HTTPAdapter | None = None,
    ):
        self._conn = None
        self._config = None
def list_databases(self, config=None) -> list[Database]:
return [
Database(
conn=self._conn,
database_name="dify",
config=self._config,
)
]
def create_database(self, database_name: str, config=None) -> Database:
return Database(conn=self._conn, database_name=database_name, config=config)
def list_table(self, config=None) -> list[Table]:
return []
def drop_table(self, table_name: str, config=None):
return {"code": 0, "msg": "Success"}
def create_table(
self,
table_name: str,
replication: int,
partition: int,
schema,
enable_dynamic_field=False,
description: str = "",
config=None,
) -> Table:
return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
def describe_table(self, table_name: str, config=None) -> Table:
return Table(
self,
table_name,
3,
1,
None,
enable_dynamic_field=False,
description="table for dify",
config=config,
state=TableState.NORMAL,
)
def upsert(self, rows, config=None):
return {"code": 0, "msg": "operation success", "affectedCount": 1}
def rebuild_index(self, index_name: str, config=None):
return {"code": 0, "msg": "Success"}
def describe_index(self, index_name: str, config=None):
return VectorIndex(
index_name=index_name,
index_type=IndexType.HNSW,
field="vector",
metric_type=MetricType.L2,
params=HNSWParams(m=16, efconstruction=200),
auto_build=False,
state=IndexState.NORMAL,
)
def query(
self,
primary_key,
partition_key=None,
projections=None,
retrieve_vector=False,
read_consistency=ReadConsistency.EVENTUAL,
config=None,
):
return {
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"code": 0,
"msg": "Success",
}
def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
return {"code": 0, "msg": "Success"}
def search(
self,
anns,
partition_key=None,
projections=None,
retrieve_vector=False,
read_consistency=ReadConsistency.EVENTUAL,
config=None,
):
return {
"rows": [
{
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"distance": 0.1,
"score": 0.5,
}
],
"code": 0,
"msg": "Success",
}
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
yield
if MOCK:
monkeypatch.undo()


@@ -0,0 +1,36 @@
from unittest.mock import MagicMock
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]
class BaiduVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = BaiduVector(
"dify",
BaiduConfig(
endpoint="http://127.0.0.1:5287",
account="root",
api_key="dify",
database="dify",
shard=1,
replicas=3,
),
)
def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1
def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
BaiduVectorTest().run_all_tests()
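A note on running it: the fixture above only monkeypatches the pymochow client when the MOCK_SWITCH environment variable is "true" (see the MOCK flag in the mock module), so running this test module with MOCK_SWITCH=true exercises BaiduVector without a live Baidu VDB instance; without the switch, the test expects a reachable endpoint at http://127.0.0.1:5287.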


@@ -462,6 +462,15 @@ ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
# Baidu vector configuration, available only when VECTOR_STORE is `baidu`
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
# ------------------------------
# Knowledge Configuration
# ------------------------------


@@ -165,6 +165,13 @@ x-shared-env: &shared-api-worker-env
TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify}
BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
ETL_TYPE: ${ETL_TYPE:-dify}