mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 21:09:05 +08:00
feat: support tencent vector db (#3568)
This commit is contained in:
parent
9ed21737d5
commit
4080f7b8ad
@ -99,6 +99,15 @@ RELYT_USER=postgres
|
||||
RELYT_PASSWORD=postgres
|
||||
RELYT_DATABASE=postgres
|
||||
|
||||
# Tencent configuration
|
||||
TENCENT_VECTOR_DB_URL=http://127.0.0.1
|
||||
TENCENT_VECTOR_DB_API_KEY=dify
|
||||
TENCENT_VECTOR_DB_TIMEOUT=30
|
||||
TENCENT_VECTOR_DB_USERNAME=dify
|
||||
TENCENT_VECTOR_DB_DATABASE=dify
|
||||
TENCENT_VECTOR_DB_SHARD=1
|
||||
TENCENT_VECTOR_DB_REPLICAS=2
|
||||
|
||||
# PGVECTO_RS configuration
|
||||
PGVECTO_RS_HOST=localhost
|
||||
PGVECTO_RS_PORT=5431
|
||||
|
@ -309,6 +309,14 @@ def migrate_knowledge_vector_database():
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.TENCENT:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.TENCENT,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.PGVECTOR:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
@ -288,6 +288,16 @@ class Config:
|
||||
self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
|
||||
self.RELYT_DATABASE = get_env('RELYT_DATABASE')
|
||||
|
||||
|
||||
# tencent settings
|
||||
self.TENCENT_VECTOR_DB_URL = get_env('TENCENT_VECTOR_DB_URL')
|
||||
self.TENCENT_VECTOR_DB_API_KEY = get_env('TENCENT_VECTOR_DB_API_KEY')
|
||||
self.TENCENT_VECTOR_DB_TIMEOUT = get_env('TENCENT_VECTOR_DB_TIMEOUT')
|
||||
self.TENCENT_VECTOR_DB_USERNAME = get_env('TENCENT_VECTOR_DB_USERNAME')
|
||||
self.TENCENT_VECTOR_DB_DATABASE = get_env('TENCENT_VECTOR_DB_DATABASE')
|
||||
self.TENCENT_VECTOR_DB_SHARD = get_env('TENCENT_VECTOR_DB_SHARD')
|
||||
self.TENCENT_VECTOR_DB_REPLICAS = get_env('TENCENT_VECTOR_DB_REPLICAS')
|
||||
|
||||
# pgvecto rs settings
|
||||
self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
|
||||
self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
|
||||
|
@ -480,9 +480,8 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = current_app.config['VECTOR_STORE']
|
||||
|
||||
match vector_type:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
@ -504,7 +503,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCEN:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
|
0
api/core/rag/datasource/vdb/tencent/__init__.py
Normal file
0
api/core/rag/datasource/vdb/tencent/__init__.py
Normal file
227
api/core/rag/datasource/vdb/tencent/tencent_vector.py
Normal file
227
api/core/rag/datasource/vdb/tencent/tencent_vector.py
Normal file
@ -0,0 +1,227 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
from tcvectordb import VectorDBClient
|
||||
from tcvectordb.model import document, enum
|
||||
from tcvectordb.model import index as vdb_index
|
||||
from tcvectordb.model.document import Filter
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class TencentConfig(BaseModel):
|
||||
url: str
|
||||
api_key: Optional[str]
|
||||
timeout: float = 30
|
||||
username: Optional[str]
|
||||
database: Optional[str]
|
||||
index_type: str = "HNSW"
|
||||
metric_type: str = "L2"
|
||||
shard: int = 1,
|
||||
replicas: int = 2,
|
||||
|
||||
def to_tencent_params(self):
|
||||
return {
|
||||
'url': self.url,
|
||||
'username': self.username,
|
||||
'key': self.api_key,
|
||||
'timeout': self.timeout
|
||||
}
|
||||
|
||||
|
||||
class TencentVector(BaseVector):
|
||||
field_id: str = "id"
|
||||
field_vector: str = "vector"
|
||||
field_text: str = "text"
|
||||
field_metadata: str = "metadata"
|
||||
|
||||
def __init__(self, collection_name: str, config: TencentConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = VectorDBClient(**self._client_config.to_tencent_params())
|
||||
self._db = self._init_database()
|
||||
|
||||
def _init_database(self):
|
||||
exists = False
|
||||
for db in self._client.list_databases():
|
||||
if db.database_name == self._client_config.database:
|
||||
exists = True
|
||||
break
|
||||
if exists:
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'tencent'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
|
||||
def _has_collection(self) -> bool:
|
||||
collections = self._db.list_collections()
|
||||
for collection in collections:
|
||||
if collection.collection_name == self._collection_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _create_collection(self, dimension: int) -> None:
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
if self._has_collection():
|
||||
return
|
||||
|
||||
self.delete()
|
||||
index_type = None
|
||||
for k, v in enum.IndexType.__members__.items():
|
||||
if k == self._client_config.index_type:
|
||||
index_type = v
|
||||
if index_type is None:
|
||||
raise ValueError("unsupported index_type")
|
||||
metric_type = None
|
||||
for k, v in enum.MetricType.__members__.items():
|
||||
if k == self._client_config.metric_type:
|
||||
metric_type = v
|
||||
if metric_type is None:
|
||||
raise ValueError("unsupported metric_type")
|
||||
params = vdb_index.HNSWParams(m=16, efconstruction=200)
|
||||
index = vdb_index.Index(
|
||||
vdb_index.FilterIndex(
|
||||
self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
|
||||
),
|
||||
vdb_index.VectorIndex(
|
||||
self.field_vector,
|
||||
dimension,
|
||||
index_type,
|
||||
metric_type,
|
||||
params,
|
||||
),
|
||||
vdb_index.FilterIndex(
|
||||
self.field_text, enum.FieldType.String, enum.IndexType.FILTER
|
||||
),
|
||||
vdb_index.FilterIndex(
|
||||
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
|
||||
),
|
||||
)
|
||||
|
||||
self._db.create_collection(
|
||||
name=self._collection_name,
|
||||
shard=self._client_config.shard,
|
||||
replicas=self._client_config.replicas,
|
||||
description="Collection for Dify",
|
||||
index=index,
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._create_collection(len(embeddings[0]))
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
total_count = len(embeddings)
|
||||
docs = []
|
||||
for id in range(0, total_count):
|
||||
if metadatas is None:
|
||||
continue
|
||||
metadata = json.dumps(metadatas[id])
|
||||
doc = document.Document(
|
||||
id=metadatas[id]["doc_id"],
|
||||
vector=embeddings[id],
|
||||
text=texts[id],
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
docs = self._db.collection(self._collection_name).query(document_ids=[id])
|
||||
if docs and len(docs) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
||||
res = self._db.collection(self._collection_name).search(vectors=[query_vector],
|
||||
params=document.HNSWSearchParams(
|
||||
ef=kwargs.get("ef", 10)),
|
||||
retrieve_vector=False,
|
||||
limit=kwargs.get('top_k', 4),
|
||||
timeout=self._client_config.timeout,
|
||||
)
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def _get_search_res(self, res, score_threshold):
|
||||
docs = []
|
||||
if res is None or len(res) == 0:
|
||||
return docs
|
||||
|
||||
for result in res[0]:
|
||||
meta = result.get(self.field_metadata)
|
||||
if meta is not None:
|
||||
meta = json.loads(meta)
|
||||
score = 1 - result.get("score", 0.0)
|
||||
if score > score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=result.get(self.field_text), metadata=meta)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
self._db.drop_collection(name=self._collection_name)
|
||||
|
||||
|
||||
|
||||
|
||||
class TencentVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return TencentVector(
|
||||
collection_name=collection_name,
|
||||
config=TencentConfig(
|
||||
url=config.get('TENCENT_VECTOR_DB_URL'),
|
||||
api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
|
||||
timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
|
||||
username=config.get('TENCENT_VECTOR_DB_USERNAME'),
|
||||
database=config.get('TENCENT_VECTOR_DB_DATABASE'),
|
||||
shard=config.get('TENCENT_VECTOR_DB_SHARD'),
|
||||
replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
|
||||
)
|
||||
)
|
@ -39,7 +39,6 @@ class Vector:
|
||||
def _init_vector(self) -> BaseVector:
|
||||
config = current_app.config
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict['type']
|
||||
|
||||
@ -76,6 +75,9 @@ class Vector:
|
||||
case VectorType.WEAVIATE:
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
|
||||
return WeaviateVectorFactory
|
||||
case VectorType.TENCENT:
|
||||
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
|
||||
return TencentVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
@ -10,3 +10,4 @@ class VectorType(str, Enum):
|
||||
RELYT = 'relyt'
|
||||
TIDB_VECTOR = 'tidb_vector'
|
||||
WEAVIATE = 'weaviate'
|
||||
TENCENT = 'tencent'
|
||||
|
45
api/poetry.lock
generated
45
api/poetry.lock
generated
@ -1439,6 +1439,23 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pill
|
||||
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
|
||||
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
|
||||
|
||||
[[package]]
|
||||
name = "cos-python-sdk-v5"
|
||||
version = "1.9.29"
|
||||
description = "cos-python-sdk-v5"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "cos-python-sdk-v5-1.9.29.tar.gz", hash = "sha256:1bb07022368d178e7a50a3cc42e0d6cbf4b0bef2af12a3bb8436904339cdec8e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
crcmod = "*"
|
||||
pycryptodome = "*"
|
||||
requests = ">=2.8"
|
||||
six = "*"
|
||||
xmltodict = "*"
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.2.7"
|
||||
@ -7411,6 +7428,21 @@ files = [
|
||||
[package.extras]
|
||||
widechars = ["wcwidth"]
|
||||
|
||||
[[package]]
|
||||
name = "tcvectordb"
|
||||
version = "1.3.2"
|
||||
description = "Tencent VectorDB Python SDK"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "tcvectordb-1.3.2-py3-none-any.whl", hash = "sha256:c4b6922d5df4cf14fcd3e61220d9374d1d53ec7270c254216ae35f8a752908f3"},
|
||||
{file = "tcvectordb-1.3.2.tar.gz", hash = "sha256:2772f5871a69744ffc7c970b321312d626078533a721de3c744059a81aab419e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cos-python-sdk-v5 = ">=1.9.26"
|
||||
requests = "*"
|
||||
|
||||
[[package]]
|
||||
name = "tenacity"
|
||||
version = "8.3.0"
|
||||
@ -8641,6 +8673,17 @@ files = [
|
||||
{file = "XlsxWriter-3.2.0.tar.gz", hash = "sha256:9977d0c661a72866a61f9f7a809e25ebbb0fb7036baa3b9fe74afcfca6b3cb8c"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xmltodict"
|
||||
version = "0.13.0"
|
||||
description = "Makes working with XML feel like you are working with JSON"
|
||||
optional = false
|
||||
python-versions = ">=3.4"
|
||||
files = [
|
||||
{file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
|
||||
{file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarl"
|
||||
version = "1.9.4"
|
||||
@ -8878,4 +8921,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "32a9ac027beabdb863fb33886bbf6f0000cbddf4d6089cbdb5c5dbfba23b29b4"
|
||||
content-hash = "e967aa4b61dc7c40f2f50eb325038da1dc0ff633d8f778e7a7560bdabce744dc"
|
||||
|
@ -179,6 +179,7 @@ google-cloud-aiplatform = "1.49.0"
|
||||
vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
|
||||
kaleido = "0.2.1"
|
||||
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
|
||||
tcvectordb = "1.3.2"
|
||||
chromadb = "~0.5.0"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
|
@ -78,6 +78,7 @@ lxml==5.1.0
|
||||
pydantic~=2.7.4
|
||||
pydantic_extra_types~=2.8.1
|
||||
pgvecto-rs==0.1.4
|
||||
tcvectordb==1.3.2
|
||||
firecrawl-py==0.0.5
|
||||
oss2==2.18.5
|
||||
pgvector==0.2.5
|
||||
|
0
api/tests/integration_tests/vdb/__mock/__init__.py
Normal file
0
api/tests/integration_tests/vdb/__mock/__init__.py
Normal file
132
api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
132
api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
@ -0,0 +1,132 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from requests.adapters import HTTPAdapter
|
||||
from tcvectordb import VectorDBClient
|
||||
from tcvectordb.model.database import Collection, Database
|
||||
from tcvectordb.model.document import Document, Filter
|
||||
from tcvectordb.model.enum import ReadConsistency
|
||||
from tcvectordb.model.index import Index
|
||||
from xinference_client.types import Embedding
|
||||
|
||||
|
||||
class MockTcvectordbClass:
|
||||
|
||||
def VectorDBClient(self, url=None, username='', key='',
|
||||
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
|
||||
timeout=5,
|
||||
adapter: HTTPAdapter = None):
|
||||
self._conn = None
|
||||
self._read_consistency = read_consistency
|
||||
|
||||
def list_databases(self) -> list[Database]:
|
||||
return [
|
||||
Database(
|
||||
conn=self._conn,
|
||||
read_consistency=self._read_consistency,
|
||||
name='dify',
|
||||
)]
|
||||
|
||||
def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
|
||||
return []
|
||||
|
||||
def drop_collection(self, name: str, timeout: Optional[float] = None):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
shard: int,
|
||||
replicas: int,
|
||||
description: str,
|
||||
index: Index,
|
||||
embedding: Embedding = None,
|
||||
timeout: float = None,
|
||||
) -> Collection:
|
||||
return Collection(self, name, shard, replicas, description, index, embedding=embedding,
|
||||
read_consistency=self._read_consistency, timeout=timeout)
|
||||
|
||||
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
|
||||
collection = Collection(
|
||||
self,
|
||||
name,
|
||||
shard=1,
|
||||
replicas=2,
|
||||
description=name,
|
||||
timeout=timeout
|
||||
)
|
||||
return collection
|
||||
|
||||
def collection_upsert(
|
||||
self,
|
||||
documents: list[Document],
|
||||
timeout: Optional[float] = None,
|
||||
build_index: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
|
||||
def collection_search(
|
||||
self,
|
||||
vectors: list[list[float]],
|
||||
filter: Filter = None,
|
||||
params=None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> list[list[dict]]:
|
||||
return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]
|
||||
|
||||
def collection_query(
|
||||
self,
|
||||
document_ids: Optional[list] = None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
filter: Optional[Filter] = None,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> list[dict]:
|
||||
return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]
|
||||
|
||||
def collection_delete(
|
||||
self,
|
||||
document_ids: list[str] = None,
|
||||
filter: Filter = None,
|
||||
timeout: float = None,
|
||||
):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
|
||||
monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
|
||||
monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
|
||||
monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
|
||||
monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
|
||||
monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
|
||||
monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
|
||||
monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
|
||||
monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
|
||||
monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
35
api/tests/integration_tests/vdb/tcvectordb/test_tencent.py
Normal file
35
api/tests/integration_tests/vdb/tcvectordb/test_tencent.py
Normal file
@ -0,0 +1,35 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
|
||||
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.list_databases.return_value = [{"name": "test"}]
|
||||
|
||||
class TencentVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = TencentVector("dify", TencentConfig(
|
||||
url="http://127.0.0.1",
|
||||
api_key="dify",
|
||||
timeout=30,
|
||||
username="dify",
|
||||
database="dify",
|
||||
shard=1,
|
||||
replicas=2,
|
||||
))
|
||||
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 1
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock):
|
||||
TencentVectorTest().run_all_tests()
|
||||
|
||||
|
||||
|
@ -298,6 +298,14 @@ services:
|
||||
RELYT_USER: postgres
|
||||
RELYT_PASSWORD: difyai123456
|
||||
RELYT_DATABASE: postgres
|
||||
# tencent configurations
|
||||
TENCENT_VECTOR_DB_URL: http://127.0.0.1
|
||||
TENCENT_VECTOR_DB_API_KEY: dify
|
||||
TENCENT_VECTOR_DB_TIMEOUT: 30
|
||||
TENCENT_VECTOR_DB_USERNAME: dify
|
||||
TENCENT_VECTOR_DB_DATABASE: dify
|
||||
TENCENT_VECTOR_DB_SHARD: 1
|
||||
TENCENT_VECTOR_DB_REPLICAS: 2
|
||||
# pgvector configurations
|
||||
PGVECTOR_HOST: pgvector
|
||||
PGVECTOR_PORT: 5432
|
||||
|
Loading…
x
Reference in New Issue
Block a user