mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 20:59:02 +08:00
feat:support baidu vector db (#9185)
This commit is contained in:
parent
793205afc5
commit
2ec6ffe478
@ -208,6 +208,15 @@ OPENSEARCH_USER=admin
|
||||
OPENSEARCH_PASSWORD=admin
|
||||
OPENSEARCH_SECURE=true
|
||||
|
||||
# Baidu configuration
|
||||
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
|
||||
BAIDU_VECTOR_DB_ACCOUNT=root
|
||||
BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
|
@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
|
||||
index_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.BAIDU:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.BAIDU,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
45
api/configs/middleware/vdb/baidu_vector_config.py
Normal file
45
api/configs/middleware/vdb/baidu_vector_config.py
Normal file
@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class BaiduVectorDBConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Baidu Vector Database
|
||||
"""
|
||||
|
||||
BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
|
||||
description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
|
||||
description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
|
||||
default=30000,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
|
||||
description="Account for authenticating with the Baidu Vector Database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
|
||||
description="API key for authenticating with the Baidu Vector Database service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
|
||||
description="Name of the specific Baidu Vector Database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
|
||||
description="Number of shards for the Baidu Vector Database (default is 1)",
|
||||
default=1,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
|
||||
description="Number of replicas for the Baidu Vector Database (default is 3)",
|
||||
default=3,
|
||||
)
|
@ -617,6 +617,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@ -653,6 +654,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
0
api/core/rag/datasource/vdb/baidu/__init__.py
Normal file
0
api/core/rag/datasource/vdb/baidu/__init__.py
Normal file
272
api/core/rag/datasource/vdb/baidu/baidu_vector.py
Normal file
272
api/core/rag/datasource/vdb/baidu/baidu_vector.py
Normal file
@ -0,0 +1,272 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymochow import MochowClient
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class BaiduConfig(BaseModel):
|
||||
endpoint: str
|
||||
connection_timeout_in_mills: int = 30 * 1000
|
||||
account: str
|
||||
api_key: str
|
||||
database: str
|
||||
index_type: str = "HNSW"
|
||||
metric_type: str = "L2"
|
||||
shard: int = 1
|
||||
replicas: int = 3
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["endpoint"]:
|
||||
raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
|
||||
if not values["account"]:
|
||||
raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
|
||||
if not values["api_key"]:
|
||||
raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
|
||||
if not values["database"]:
|
||||
raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
|
||||
return values
|
||||
|
||||
|
||||
class BaiduVector(BaseVector):
|
||||
field_id: str = "id"
|
||||
field_vector: str = "vector"
|
||||
field_text: str = "text"
|
||||
field_metadata: str = "metadata"
|
||||
field_app_id: str = "app_id"
|
||||
field_annotation_id: str = "annotation_id"
|
||||
index_vector: str = "vector_idx"
|
||||
|
||||
def __init__(self, collection_name: str, config: BaiduConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._db = self._init_database()
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.BAIDU
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._create_table(len(embeddings[0]))
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
total_count = len(documents)
|
||||
batch_size = 1000
|
||||
|
||||
# upsert texts and embeddings batch by batch
|
||||
table = self._db.table(self._collection_name)
|
||||
for start in range(0, total_count, batch_size):
|
||||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
for i in range(start, end, 1):
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
vector=embeddings[i],
|
||||
text=texts[i],
|
||||
metadata=json.dumps(metadatas[i]),
|
||||
app_id=metadatas[i].get("app_id", ""),
|
||||
annotation_id=metadatas[i].get("annotation_id", ""),
|
||||
)
|
||||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
# rebuild vector index after upsert finished
|
||||
table.rebuild_index(self.index_vector)
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.index_vector)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
|
||||
if res and res.code == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
quoted_ids = [f"'{id}'" for id in ids]
|
||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
res = self._db.table(self._collection_name).search(
|
||||
anns=anns,
|
||||
projections=[self.field_id, self.field_text, self.field_metadata],
|
||||
retrieve_vector=True,
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# baidu vector database doesn't support bm25 search on current version
|
||||
return []
|
||||
|
||||
def _get_search_res(self, res, score_threshold):
|
||||
docs = []
|
||||
for row in res.rows:
|
||||
row_data = row.get("row", {})
|
||||
meta = row_data.get(self.field_metadata)
|
||||
if meta is not None:
|
||||
meta = json.loads(meta)
|
||||
score = row.get("score", 0.0)
|
||||
if score > score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
|
||||
def _init_client(self, config) -> MochowClient:
|
||||
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
|
||||
client = MochowClient(config)
|
||||
return client
|
||||
|
||||
def _init_database(self):
|
||||
exists = False
|
||||
for db in self._client.list_databases():
|
||||
if db.database_name == self._client_config.database:
|
||||
exists = True
|
||||
break
|
||||
# Create database if not existed
|
||||
if exists:
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
return any(table.table_name == self._collection_name for table in tables)
|
||||
|
||||
def _create_table(self, dimension: int) -> None:
|
||||
# Try to grab distributed lock and create table
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(table_exist_cache_key):
|
||||
return
|
||||
|
||||
if self._table_existed():
|
||||
return
|
||||
|
||||
self.delete()
|
||||
|
||||
# check IndexType and MetricType
|
||||
index_type = None
|
||||
for k, v in IndexType.__members__.items():
|
||||
if k == self._client_config.index_type:
|
||||
index_type = v
|
||||
if index_type is None:
|
||||
raise ValueError("unsupported index_type")
|
||||
metric_type = None
|
||||
for k, v in MetricType.__members__.items():
|
||||
if k == self._client_config.metric_type:
|
||||
metric_type = v
|
||||
if metric_type is None:
|
||||
raise ValueError("unsupported metric_type")
|
||||
|
||||
# Construct field schema
|
||||
fields = []
|
||||
fields.append(
|
||||
Field(
|
||||
self.field_id,
|
||||
FieldType.STRING,
|
||||
primary_key=True,
|
||||
partition_key=True,
|
||||
auto_increment=False,
|
||||
not_null=True,
|
||||
)
|
||||
)
|
||||
fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
|
||||
fields.append(Field(self.field_app_id, FieldType.STRING))
|
||||
fields.append(Field(self.field_annotation_id, FieldType.STRING))
|
||||
fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
|
||||
fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
|
||||
|
||||
# Construct vector index params
|
||||
indexes = []
|
||||
indexes.append(
|
||||
VectorIndex(
|
||||
index_name="vector_idx",
|
||||
index_type=index_type,
|
||||
field="vector",
|
||||
metric_type=metric_type,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
)
|
||||
)
|
||||
|
||||
# Create table
|
||||
self._db.create_table(
|
||||
table_name=self._collection_name,
|
||||
replication=self._client_config.replicas,
|
||||
partition=Partition(partition_num=self._client_config.shard),
|
||||
schema=Schema(fields=fields, indexes=indexes),
|
||||
description="Table for Dify",
|
||||
)
|
||||
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))
|
||||
|
||||
return BaiduVector(
|
||||
collection_name=collection_name,
|
||||
config=BaiduConfig(
|
||||
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
|
||||
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
|
||||
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
|
||||
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
|
||||
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
|
||||
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
|
||||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
),
|
||||
)
|
@ -103,6 +103,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
|
||||
|
||||
return AnalyticdbVectorFactory
|
||||
case VectorType.BAIDU:
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
|
||||
|
||||
return BaiduVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
@ -16,3 +16,4 @@ class VectorType(str, Enum):
|
||||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
BAIDU = "baidu"
|
||||
|
47
api/poetry.lock
generated
47
api/poetry.lock
generated
@ -732,7 +732,7 @@ name = "bce-python-sdk"
|
||||
version = "0.9.23"
|
||||
description = "BCE SDK for python"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4"
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7"
|
||||
files = [
|
||||
{file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"},
|
||||
{file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"},
|
||||
@ -847,7 +847,7 @@ name = "botocore"
|
||||
version = "1.35.38"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">= 3.8"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"},
|
||||
{file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"},
|
||||
@ -1068,7 +1068,7 @@ name = "build"
|
||||
version = "1.2.2.post1"
|
||||
description = "A simple, correct Python build frontend"
|
||||
optional = false
|
||||
python-versions = ">= 3.8"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"},
|
||||
{file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"},
|
||||
@ -3385,7 +3385,7 @@ name = "gotrue"
|
||||
version = "2.9.2"
|
||||
description = "Python Client Library for Supabase Auth"
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"},
|
||||
{file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"},
|
||||
@ -4415,7 +4415,7 @@ name = "langfuse"
|
||||
version = "2.51.5"
|
||||
description = "A client library for accessing langfuse"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"},
|
||||
{file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"},
|
||||
@ -4440,7 +4440,7 @@ name = "langsmith"
|
||||
version = "0.1.134"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"},
|
||||
{file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"},
|
||||
@ -6429,7 +6429,7 @@ name = "postgrest"
|
||||
version = "0.17.1"
|
||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"},
|
||||
{file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"},
|
||||
@ -7047,6 +7047,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r
|
||||
dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"]
|
||||
model = ["milvus-model (>=0.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pymochow"
|
||||
version = "1.3.1"
|
||||
description = "Python SDK for mochow"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"},
|
||||
{file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
future = "*"
|
||||
orjson = "*"
|
||||
requests = "*"
|
||||
|
||||
[[package]]
|
||||
name = "pymysql"
|
||||
version = "1.1.1"
|
||||
@ -7746,7 +7762,7 @@ name = "realtime"
|
||||
version = "2.0.2"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.9,<4.0"
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"},
|
||||
{file = "realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"},
|
||||
@ -8173,7 +8189,7 @@ name = "s3transfer"
|
||||
version = "0.10.3"
|
||||
description = "An Amazon S3 Transfer Manager"
|
||||
optional = false
|
||||
python-versions = ">= 3.8"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
|
||||
{file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
|
||||
@ -8417,6 +8433,11 @@ files = [
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
|
||||
@ -8836,7 +8857,7 @@ name = "storage3"
|
||||
version = "0.8.1"
|
||||
description = "Supabase Storage client for Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"},
|
||||
{file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"},
|
||||
@ -8882,7 +8903,7 @@ name = "supabase"
|
||||
version = "2.8.1"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = ">=3.9,<4.0"
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"},
|
||||
{file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"},
|
||||
@ -8902,7 +8923,7 @@ name = "supafunc"
|
||||
version = "0.6.1"
|
||||
description = "Library for Supabase Functions"
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"},
|
||||
{file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"},
|
||||
@ -10615,4 +10636,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "cc10ee218369eb5576d1e5ac8aeeb72e8927bbcb8bd1ac1594167c45aa9d9a21"
|
||||
content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774"
|
||||
|
@ -242,6 +242,7 @@ oracledb = "~2.2.1"
|
||||
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
|
||||
pgvector = "0.2.5"
|
||||
pymilvus = "~2.4.4"
|
||||
pymochow = "1.3.1"
|
||||
qdrant-client = "1.7.3"
|
||||
tcvectordb = "1.3.2"
|
||||
tidb-vector = "0.0.9"
|
||||
|
154
api/tests/integration_tests/vdb/__mock/baiduvectordb.py
Normal file
154
api/tests/integration_tests/vdb/__mock/baiduvectordb.py
Normal file
@ -0,0 +1,154 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from pymochow import MochowClient
|
||||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
|
||||
from pymochow.model.schema import HNSWParams, VectorIndex
|
||||
from pymochow.model.table import Table
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
|
||||
class MockBaiduVectorDBClass:
|
||||
def mock_vector_db_client(
|
||||
self,
|
||||
config=None,
|
||||
adapter: HTTPAdapter = None,
|
||||
):
|
||||
self._conn = None
|
||||
self._config = None
|
||||
|
||||
def list_databases(self, config=None) -> list[Database]:
|
||||
return [
|
||||
Database(
|
||||
conn=self._conn,
|
||||
database_name="dify",
|
||||
config=self._config,
|
||||
)
|
||||
]
|
||||
|
||||
def create_database(self, database_name: str, config=None) -> Database:
|
||||
return Database(conn=self._conn, database_name=database_name, config=config)
|
||||
|
||||
def list_table(self, config=None) -> list[Table]:
|
||||
return []
|
||||
|
||||
def drop_table(self, table_name: str, config=None):
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def create_table(
|
||||
self,
|
||||
table_name: str,
|
||||
replication: int,
|
||||
partition: int,
|
||||
schema,
|
||||
enable_dynamic_field=False,
|
||||
description: str = "",
|
||||
config=None,
|
||||
) -> Table:
|
||||
return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
|
||||
|
||||
def describe_table(self, table_name: str, config=None) -> Table:
|
||||
return Table(
|
||||
self,
|
||||
table_name,
|
||||
3,
|
||||
1,
|
||||
None,
|
||||
enable_dynamic_field=False,
|
||||
description="table for dify",
|
||||
config=config,
|
||||
state=TableState.NORMAL,
|
||||
)
|
||||
|
||||
def upsert(self, rows, config=None):
|
||||
return {"code": 0, "msg": "operation success", "affectedCount": 1}
|
||||
|
||||
def rebuild_index(self, index_name: str, config=None):
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def describe_index(self, index_name: str, config=None):
|
||||
return VectorIndex(
|
||||
index_name=index_name,
|
||||
index_type=IndexType.HNSW,
|
||||
field="vector",
|
||||
metric_type=MetricType.L2,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
auto_build=False,
|
||||
state=IndexState.NORMAL,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
primary_key,
|
||||
partition_key=None,
|
||||
projections=None,
|
||||
retrieve_vector=False,
|
||||
read_consistency=ReadConsistency.EVENTUAL,
|
||||
config=None,
|
||||
):
|
||||
return {
|
||||
"row": {
|
||||
"id": "doc_id_001",
|
||||
"vector": [0.23432432, 0.8923744, 0.89238432],
|
||||
"text": "text",
|
||||
"metadata": {"doc_id": "doc_id_001"},
|
||||
},
|
||||
"code": 0,
|
||||
"msg": "Success",
|
||||
}
|
||||
|
||||
def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def search(
|
||||
self,
|
||||
anns,
|
||||
partition_key=None,
|
||||
projections=None,
|
||||
retrieve_vector=False,
|
||||
read_consistency=ReadConsistency.EVENTUAL,
|
||||
config=None,
|
||||
):
|
||||
return {
|
||||
"rows": [
|
||||
{
|
||||
"row": {
|
||||
"id": "doc_id_001",
|
||||
"vector": [0.23432432, 0.8923744, 0.89238432],
|
||||
"text": "text",
|
||||
"metadata": {"doc_id": "doc_id_001"},
|
||||
},
|
||||
"distance": 0.1,
|
||||
"score": 0.5,
|
||||
}
|
||||
],
|
||||
"code": 0,
|
||||
"msg": "Success",
|
||||
}
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
|
||||
monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
|
||||
monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
|
||||
monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
|
||||
monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
|
||||
monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
|
||||
monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
|
||||
monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
|
||||
monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
|
||||
monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
|
||||
monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
|
||||
monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
0
api/tests/integration_tests/vdb/baidu/__init__.py
Normal file
0
api/tests/integration_tests/vdb/baidu/__init__.py
Normal file
36
api/tests/integration_tests/vdb/baidu/test_baidu.py
Normal file
36
api/tests/integration_tests/vdb/baidu/test_baidu.py
Normal file
@ -0,0 +1,36 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
|
||||
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.list_databases.return_value = [{"name": "test"}]
|
||||
|
||||
|
||||
class BaiduVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = BaiduVector(
|
||||
"dify",
|
||||
BaiduConfig(
|
||||
endpoint="http://127.0.0.1:5287",
|
||||
account="root",
|
||||
api_key="dify",
|
||||
database="dify",
|
||||
shard=1,
|
||||
replicas=3,
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 1
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
|
||||
def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
|
||||
BaiduVectorTest().run_all_tests()
|
@ -462,6 +462,15 @@ ELASTICSEARCH_PORT=9200
|
||||
ELASTICSEARCH_USERNAME=elastic
|
||||
ELASTICSEARCH_PASSWORD=elastic
|
||||
|
||||
# baidu vector configurations, only available when VECTOR_STORE is `baidu`
|
||||
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
|
||||
BAIDU_VECTOR_DB_ACCOUNT=root
|
||||
BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
|
||||
# ------------------------------
|
||||
# Knowledge Configuration
|
||||
# ------------------------------
|
||||
|
@ -165,6 +165,13 @@ x-shared-env: &shared-api-worker-env
|
||||
TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
|
||||
TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
|
||||
TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
|
||||
BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
|
||||
BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
|
||||
BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify}
|
||||
BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
|
||||
BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
|
||||
BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
|
||||
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
|
||||
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
|
||||
ETL_TYPE: ${ETL_TYPE:-dify}
|
||||
|
Loading…
x
Reference in New Issue
Block a user