feat: support tencent vector db (#3568)

This commit is contained in:
quicksand 2024-06-14 19:25:17 +08:00 committed by GitHub
parent 9ed21737d5
commit 4080f7b8ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 481 additions and 5 deletions

View File

@ -99,6 +99,15 @@ RELYT_USER=postgres
RELYT_PASSWORD=postgres
RELYT_DATABASE=postgres
# Tencent configuration
TENCENT_VECTOR_DB_URL=http://127.0.0.1
TENCENT_VECTOR_DB_API_KEY=dify
TENCENT_VECTOR_DB_TIMEOUT=30
TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431

View File

@ -309,6 +309,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.TENCENT,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)

View File

@ -288,6 +288,16 @@ class Config:
self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
self.RELYT_DATABASE = get_env('RELYT_DATABASE')
# tencent settings
self.TENCENT_VECTOR_DB_URL = get_env('TENCENT_VECTOR_DB_URL')
self.TENCENT_VECTOR_DB_API_KEY = get_env('TENCENT_VECTOR_DB_API_KEY')
self.TENCENT_VECTOR_DB_TIMEOUT = get_env('TENCENT_VECTOR_DB_TIMEOUT')
self.TENCENT_VECTOR_DB_USERNAME = get_env('TENCENT_VECTOR_DB_USERNAME')
self.TENCENT_VECTOR_DB_DATABASE = get_env('TENCENT_VECTOR_DB_DATABASE')
self.TENCENT_VECTOR_DB_SHARD = get_env('TENCENT_VECTOR_DB_SHARD')
self.TENCENT_VECTOR_DB_REPLICAS = get_env('TENCENT_VECTOR_DB_REPLICAS')
# pgvecto rs settings
self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')

View File

@ -480,9 +480,8 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
return {
'retrieval_method': [
'semantic_search'
@ -504,7 +503,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
return {
'retrieval_method': [
'semantic_search'

View File

@ -0,0 +1,227 @@
import json
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class TencentConfig(BaseModel):
    """Connection and collection settings for Tencent Cloud VectorDB.

    ``shard``/``replicas`` control collection topology; ``index_type`` and
    ``metric_type`` must name members of the SDK's ``enum.IndexType`` /
    ``enum.MetricType``.
    """

    url: str
    api_key: Optional[str]
    timeout: float = 30
    username: Optional[str]
    database: Optional[str]
    index_type: str = "HNSW"
    metric_type: str = "L2"
    # BUG FIX: these two defaults previously ended with trailing commas,
    # which made them the tuples (1,) / (2,) instead of ints -- the values
    # are passed straight to create_collection(shard=..., replicas=...).
    shard: int = 1
    replicas: int = 2

    def to_tencent_params(self):
        """Return the keyword arguments accepted by tcvectordb.VectorDBClient."""
        return {
            'url': self.url,
            'username': self.username,
            'key': self.api_key,
            'timeout': self.timeout
        }
class TencentVector(BaseVector):
    """BaseVector implementation backed by Tencent Cloud VectorDB.

    Each document is stored with four fields: a primary-key id (taken from
    ``metadata['doc_id']``), the embedding vector, the raw text, and the
    metadata serialized as a JSON string.
    """

    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    def __init__(self, collection_name: str, config: TencentConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._client = VectorDBClient(**self._client_config.to_tencent_params())
        self._db = self._init_database()

    def _init_database(self):
        """Return the configured database, creating it on first use."""
        for db in self._client.list_databases():
            if db.database_name == self._client_config.database:
                return self._client.database(self._client_config.database)
        return self._client.create_database(database_name=self._client_config.database)

    def get_type(self) -> str:
        # VectorType is a str enum, so this compares equal to the
        # literal 'tencent' used elsewhere in index structs.
        return VectorType.TENCENT

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self._collection_name}
        }

    def _has_collection(self) -> bool:
        return any(
            collection.collection_name == self._collection_name
            for collection in self._db.list_collections()
        )

    def _create_collection(self, dimension: int) -> None:
        """Create the collection if needed (idempotent, redis-lock guarded)."""
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            if self._has_collection():
                return
            # Drop any stale collection with the same name before recreating.
            self.delete()

            # Resolve the configured names to SDK enum members; fail fast on typos.
            index_type = enum.IndexType.__members__.get(self._client_config.index_type)
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = enum.MetricType.__members__.get(self._client_config.metric_type)
            if metric_type is None:
                raise ValueError("unsupported metric_type")

            params = vdb_index.HNSWParams(m=16, efconstruction=200)
            index = vdb_index.Index(
                vdb_index.FilterIndex(
                    self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
                ),
                vdb_index.VectorIndex(
                    self.field_vector,
                    dimension,
                    index_type,
                    metric_type,
                    params,
                ),
                vdb_index.FilterIndex(
                    self.field_text, enum.FieldType.String, enum.IndexType.FILTER
                ),
                vdb_index.FilterIndex(
                    self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
                ),
            )
            self._db.create_collection(
                name=self._collection_name,
                shard=self._client_config.shard,
                replicas=self._client_config.replicas,
                description="Collection for Dify",
                index=index,
            )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self._create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Upsert documents keyed by metadata['doc_id'].

        BUG FIX: the previous loop tested ``metadatas is None`` (the whole
        list, which is never None) instead of the per-document metadata, so a
        document without metadata crashed on subscription. It also shadowed
        the builtin ``id``.
        """
        docs = []
        for doc, vector in zip(documents, embeddings):
            metadata = doc.metadata
            if metadata is None:
                # No metadata means no doc_id to key the upsert on; skip.
                continue
            docs.append(document.Document(
                id=metadata["doc_id"],
                vector=vector,
                text=doc.page_content,
                metadata=json.dumps(metadata),
            ))
        self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)

    def text_exists(self, id: str) -> bool:
        docs = self._db.collection(self._collection_name).query(document_ids=[id])
        return bool(docs)

    def delete_by_ids(self, ids: list[str]) -> None:
        self._db.collection(self._collection_name).delete(document_ids=ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        res = self._db.collection(self._collection_name).search(
            vectors=[query_vector],
            params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
            retrieve_vector=False,
            limit=kwargs.get('top_k', 4),
            timeout=self._client_config.timeout,
        )
        # Treat a falsy threshold (None, 0, missing) uniformly as 0.0.
        score_threshold = kwargs.get('score_threshold') or 0.0
        return self._get_search_res(res, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # Full-text search is not supported by this backend.
        return []

    def _get_search_res(self, res, score_threshold):
        """Convert SDK search results to Documents, filtering by similarity.

        BUG FIX: a hit without a metadata field previously left ``meta`` as
        None and crashed on ``meta["score"] = score``; default to {} instead.
        """
        docs = []
        if not res:
            return docs
        for result in res[0]:
            meta = result.get(self.field_metadata)
            meta = json.loads(meta) if meta is not None else {}
            # The SDK returns a distance-like score; convert to similarity.
            # NOTE(review): this assumes a distance metric in [0, 1] -- confirm
            # against the configured metric_type.
            score = 1 - result.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                docs.append(Document(page_content=result.get(self.field_text), metadata=meta))
        return docs

    def delete(self) -> None:
        self._db.drop_collection(name=self._collection_name)
class TencentVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
        """Build a TencentVector for the dataset, persisting index metadata on first use."""
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            # BUG FIX: this previously tagged the index struct with
            # VectorType.TIDB_VECTOR (copy-paste), which would make
            # Vector._init_vector dispatch the dataset to the TiDB store
            # on subsequent loads instead of back to this factory.
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

        config = current_app.config
        return TencentVector(
            collection_name=collection_name,
            config=TencentConfig(
                url=config.get('TENCENT_VECTOR_DB_URL'),
                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
            )
        )

View File

@ -39,7 +39,6 @@ class Vector:
def _init_vector(self) -> BaseVector:
config = current_app.config
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
@ -76,6 +75,9 @@ class Vector:
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -10,3 +10,4 @@ class VectorType(str, Enum):
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
TENCENT = 'tencent'

45
api/poetry.lock generated
View File

@ -1439,6 +1439,23 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pill
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
[[package]]
name = "cos-python-sdk-v5"
version = "1.9.29"
description = "cos-python-sdk-v5"
optional = false
python-versions = "*"
files = [
{file = "cos-python-sdk-v5-1.9.29.tar.gz", hash = "sha256:1bb07022368d178e7a50a3cc42e0d6cbf4b0bef2af12a3bb8436904339cdec8e"},
]
[package.dependencies]
crcmod = "*"
pycryptodome = "*"
requests = ">=2.8"
six = "*"
xmltodict = "*"
[[package]]
name = "coverage"
version = "7.2.7"
@ -7411,6 +7428,21 @@ files = [
[package.extras]
widechars = ["wcwidth"]
[[package]]
name = "tcvectordb"
version = "1.3.2"
description = "Tencent VectorDB Python SDK"
optional = false
python-versions = ">=3"
files = [
{file = "tcvectordb-1.3.2-py3-none-any.whl", hash = "sha256:c4b6922d5df4cf14fcd3e61220d9374d1d53ec7270c254216ae35f8a752908f3"},
{file = "tcvectordb-1.3.2.tar.gz", hash = "sha256:2772f5871a69744ffc7c970b321312d626078533a721de3c744059a81aab419e"},
]
[package.dependencies]
cos-python-sdk-v5 = ">=1.9.26"
requests = "*"
[[package]]
name = "tenacity"
version = "8.3.0"
@ -8641,6 +8673,17 @@ files = [
{file = "XlsxWriter-3.2.0.tar.gz", hash = "sha256:9977d0c661a72866a61f9f7a809e25ebbb0fb7036baa3b9fe74afcfca6b3cb8c"},
]
[[package]]
name = "xmltodict"
version = "0.13.0"
description = "Makes working with XML feel like you are working with JSON"
optional = false
python-versions = ">=3.4"
files = [
{file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
{file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
]
[[package]]
name = "yarl"
version = "1.9.4"
@ -8878,4 +8921,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "32a9ac027beabdb863fb33886bbf6f0000cbddf4d6089cbdb5c5dbfba23b29b4"
content-hash = "e967aa4b61dc7c40f2f50eb325038da1dc0ff633d8f778e7a7560bdabce744dc"

View File

@ -179,6 +179,7 @@ google-cloud-aiplatform = "1.49.0"
vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
kaleido = "0.2.1"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tcvectordb = "1.3.2"
chromadb = "~0.5.0"
[tool.poetry.group.dev]

View File

@ -78,6 +78,7 @@ lxml==5.1.0
pydantic~=2.7.4
pydantic_extra_types~=2.8.1
pgvecto-rs==0.1.4
tcvectordb==1.3.2
firecrawl-py==0.0.5
oss2==2.18.5
pgvector==0.2.5

View File

@ -0,0 +1,132 @@
import os
from typing import Optional
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
from tcvectordb import VectorDBClient
from tcvectordb.model.database import Collection, Database
from tcvectordb.model.document import Document, Filter
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import Index
from xinference_client.types import Embedding
class MockTcvectordbClass:
    """Stand-in implementations monkeypatched over the tcvectordb SDK.

    Method names and signatures mirror the real client/database/collection
    APIs; bodies return fixed, deterministic payloads.
    """

    @staticmethod
    def _ok() -> dict:
        # Fresh success payload per call, mirroring the SDK's response shape.
        return {"code": 0, "msg": "operation success"}

    def VectorDBClient(self, url=None, username='', key='',
                       read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
                       timeout=5,
                       adapter: HTTPAdapter = None):
        # Replaces VectorDBClient.__init__: no network connection is made.
        self._conn = None
        self._read_consistency = read_consistency

    def list_databases(self) -> list[Database]:
        db = Database(
            conn=self._conn,
            read_consistency=self._read_consistency,
            name='dify',
        )
        return [db]

    def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
        # Pretend no collections exist yet.
        return []

    def drop_collection(self, name: str, timeout: Optional[float] = None):
        return self._ok()

    def create_collection(
            self,
            name: str,
            shard: int,
            replicas: int,
            description: str,
            index: Index,
            embedding: Embedding = None,
            timeout: float = None,
    ) -> Collection:
        return Collection(
            self, name, shard, replicas, description, index,
            embedding=embedding,
            read_consistency=self._read_consistency,
            timeout=timeout,
        )

    def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
        return Collection(
            self,
            name,
            shard=1,
            replicas=2,
            description=name,
            timeout=timeout
        )

    def collection_upsert(
            self,
            documents: list[Document],
            timeout: Optional[float] = None,
            build_index: bool = True,
            **kwargs
    ):
        return self._ok()

    def collection_search(
            self,
            vectors: list[list[float]],
            filter: Filter = None,
            params=None,
            retrieve_vector: bool = False,
            limit: int = 10,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[list[dict]]:
        hit = {'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}
        return [[hit]]

    def collection_query(
            self,
            document_ids: Optional[list] = None,
            retrieve_vector: bool = False,
            limit: Optional[int] = None,
            offset: Optional[int] = None,
            filter: Optional[Filter] = None,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[dict]:
        return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]

    def collection_delete(
            self,
            document_ids: list[str] = None,
            filter: Filter = None,
            timeout: float = None,
    ):
        return self._ok()
# Mocking is opt-in via the MOCK_SWITCH environment variable.
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Monkeypatch the tcvectordb SDK with MockTcvectordbClass when MOCK is set."""
    if MOCK:
        # (target class, attribute, replacement) -- one row per patched API.
        patches = (
            (VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient),
            (VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases),
            (Database, 'collection', MockTcvectordbClass.describe_collection),
            (Database, 'list_collections', MockTcvectordbClass.list_collections),
            (Database, 'drop_collection', MockTcvectordbClass.drop_collection),
            (Database, 'create_collection', MockTcvectordbClass.create_collection),
            (Collection, 'upsert', MockTcvectordbClass.collection_upsert),
            (Collection, 'search', MockTcvectordbClass.collection_search),
            (Collection, 'query', MockTcvectordbClass.collection_query),
            (Collection, 'delete', MockTcvectordbClass.collection_delete),
        )
        for target, attr, replacement in patches:
            monkeypatch.setattr(target, attr, replacement)

    yield

    if MOCK:
        monkeypatch.undo()

View File

@ -0,0 +1,35 @@
from unittest.mock import MagicMock
from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
# NOTE(review): this MagicMock client is never referenced below -- the tests
# rely on the setup_tcvectordb_mock fixture instead; consider removing it.
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]
class TencentVectorTest(AbstractVectorTest):
    """Shared vector-store test suite wired to a (mocked) Tencent VectorDB."""

    def __init__(self):
        super().__init__()
        config = TencentConfig(
            url="http://127.0.0.1",
            api_key="dify",
            timeout=30,
            username="dify",
            database="dify",
            shard=1,
            replicas=2,
        )
        self.vector = TencentVector("dify", config)

    def search_by_vector(self):
        # The mocked search endpoint always returns exactly one hit.
        results = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(results) == 1

    def search_by_full_text(self):
        # The Tencent backend does not implement full-text search.
        results = self.vector.search_by_full_text(query=get_example_text())
        assert len(results) == 0
def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
    """Run the full shared suite against the mocked Tencent backend."""
    TencentVectorTest().run_all_tests()

View File

@ -298,6 +298,14 @@ services:
RELYT_USER: postgres
RELYT_PASSWORD: difyai123456
RELYT_DATABASE: postgres
# tencent configurations
TENCENT_VECTOR_DB_URL: http://127.0.0.1
TENCENT_VECTOR_DB_API_KEY: dify
TENCENT_VECTOR_DB_TIMEOUT: 30
TENCENT_VECTOR_DB_USERNAME: dify
TENCENT_VECTOR_DB_DATABASE: dify
TENCENT_VECTOR_DB_SHARD: 1
TENCENT_VECTOR_DB_REPLICAS: 2
# pgvector configurations
PGVECTOR_HOST: pgvector
PGVECTOR_PORT: 5432