mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 01:39:04 +08:00
Add Volcengine VikingDB as new vector provider (#9287)
This commit is contained in:
parent
1ec83e4969
commit
d15ba3939d
@ -111,7 +111,7 @@ SUPABASE_URL=your-server-url
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@ -220,6 +220,15 @@ BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
|
||||
# ViKingDB configuration
|
||||
VIKINGDB_ACCESS_KEY=your-ak
|
||||
VIKINGDB_SECRET_KEY=your-sk
|
||||
VIKINGDB_REGION=cn-shanghai
|
||||
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
|
||||
VIKINGDB_SCHEMA=http
|
||||
VIKINGDB_CONNECTION_TIMEOUT=30
|
||||
VIKINGDB_SOCKET_TIMEOUT=30
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
|
@ -28,6 +28,7 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
|
||||
from configs.middleware.vdb.relyt_config import RelytConfig
|
||||
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
|
||||
from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
|
||||
@ -243,5 +244,6 @@ class MiddlewareConfig(
|
||||
WeaviateConfig,
|
||||
ElasticsearchConfig,
|
||||
InternalTestConfig,
|
||||
VikingDBConfig,
|
||||
):
|
||||
pass
|
||||
|
37
api/configs/middleware/vdb/vikingdb_config.py
Normal file
37
api/configs/middleware/vdb/vikingdb_config.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VikingDBConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to Volcengine VikingDB.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.volcengine.com/docs/6291/65568
|
||||
"""
|
||||
|
||||
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
||||
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
|
||||
)
|
||||
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
||||
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
|
||||
)
|
||||
VIKINGDB_REGION: Optional[str] = Field(
|
||||
default="cn-shanghai",
|
||||
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
||||
)
|
||||
VIKINGDB_HOST: Optional[str] = Field(
|
||||
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
||||
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
||||
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
||||
)
|
||||
VIKINGDB_SCHEME: Optional[str] = Field(
|
||||
default="http",
|
||||
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
||||
)
|
||||
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
|
||||
default=30, description="The connection timeout of the Volcengine VikingDB service."
|
||||
)
|
||||
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
|
||||
default=30, description="The socket timeout of the Volcengine VikingDB service."
|
||||
)
|
@ -618,6 +618,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@ -655,6 +656,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
@ -107,6 +107,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
|
||||
|
||||
return BaiduVectorFactory
|
||||
case VectorType.VIKINGDB:
|
||||
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
|
||||
|
||||
return VikingDBVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
@ -17,3 +17,4 @@ class VectorType(str, Enum):
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
BAIDU = "baidu"
|
||||
VIKINGDB = "vikingdb"
|
||||
|
0
api/core/rag/datasource/vdb/vikingdb/__init__.py
Normal file
0
api/core/rag/datasource/vdb/vikingdb/__init__.py
Normal file
239
api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
Normal file
239
api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
Normal file
@ -0,0 +1,239 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from volcengine.viking_db import (
|
||||
Data,
|
||||
DistanceType,
|
||||
Field,
|
||||
FieldType,
|
||||
IndexType,
|
||||
QuantType,
|
||||
VectorIndexParams,
|
||||
VikingDBService,
|
||||
)
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field as vdb_Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class VikingDBConfig(BaseModel):
|
||||
access_key: str
|
||||
secret_key: str
|
||||
host: str
|
||||
region: str
|
||||
scheme: str
|
||||
connection_timeout: int
|
||||
socket_timeout: int
|
||||
index_type: str = IndexType.HNSW
|
||||
distance: str = DistanceType.L2
|
||||
quant: str = QuantType.Float
|
||||
|
||||
|
||||
class VikingDBVector(BaseVector):
|
||||
def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig):
|
||||
super().__init__(collection_name)
|
||||
self._group_id = group_id
|
||||
self._client_config = config
|
||||
self._index_name = f"{self._collection_name}_idx"
|
||||
self._client = VikingDBService(
|
||||
host=config.host,
|
||||
region=config.region,
|
||||
scheme=config.scheme,
|
||||
connection_timeout=config.connection_timeout,
|
||||
socket_timeout=config.socket_timeout,
|
||||
ak=config.access_key,
|
||||
sk=config.secret_key,
|
||||
)
|
||||
|
||||
def _has_collection(self) -> bool:
|
||||
try:
|
||||
self._client.get_collection(self._collection_name)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _has_index(self) -> bool:
|
||||
try:
|
||||
self._client.get_index(self._collection_name, self._index_name)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
if not self._has_collection():
|
||||
fields = [
|
||||
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
|
||||
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
|
||||
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
|
||||
]
|
||||
|
||||
self._client.create_collection(
|
||||
collection_name=self._collection_name,
|
||||
fields=fields,
|
||||
description="Collection For Dify",
|
||||
)
|
||||
|
||||
if not self._has_index():
|
||||
vector_index = VectorIndexParams(
|
||||
distance=self._client_config.distance,
|
||||
index_type=self._client_config.index_type,
|
||||
quant=self._client_config.quant,
|
||||
)
|
||||
|
||||
self._client.create_index(
|
||||
collection_name=self._collection_name,
|
||||
index_name=self._index_name,
|
||||
vector_index=vector_index,
|
||||
partition_by=vdb_Field.GROUP_KEY.value,
|
||||
description="Index For Dify",
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.VIKINGDB
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
page_contents = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
docs = []
|
||||
|
||||
for i, page_content in enumerate(page_contents):
|
||||
metadata = {}
|
||||
if metadatas is not None:
|
||||
for key, val in metadatas[i].items():
|
||||
metadata[key] = val
|
||||
doc = Data(
|
||||
{
|
||||
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
|
||||
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
|
||||
vdb_Field.CONTENT_KEY.value: page_content,
|
||||
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
|
||||
vdb_Field.GROUP_KEY.value: self._group_id,
|
||||
}
|
||||
)
|
||||
docs.append(doc)
|
||||
|
||||
self._client.get_collection(self._collection_name).upsert_data(docs)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
docs = self._client.get_collection(self._collection_name).fetch_data(id)
|
||||
not_exists_str = "data does not exist"
|
||||
if docs is not None and not_exists_str not in docs.fields.get("message", ""):
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.get_collection(self._collection_name).delete_data(ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
# Note: Metadata field value is an dict, but vikingdb field
|
||||
# not support json type
|
||||
results = self._client.get_index(self._collection_name, self._index_name).search(
|
||||
filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
|
||||
# max value is 5000
|
||||
limit=5000,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return []
|
||||
|
||||
ids = []
|
||||
for result in results:
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
if metadata.get(key) == value:
|
||||
ids.append(result.id)
|
||||
return ids
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
|
||||
query_vector, limit=kwargs.get("top_k", 50)
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(results, score_threshold)
|
||||
|
||||
def _get_search_res(self, results, score_threshold):
|
||||
if len(results) == 0:
|
||||
return []
|
||||
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
if result.score > score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
if self._has_index():
|
||||
self._client.drop_index(self._collection_name, self._index_name)
|
||||
if self._has_collection():
|
||||
self._client.drop_collection(self._collection_name)
|
||||
|
||||
|
||||
class VikingDBVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name))
|
||||
|
||||
if dify_config.VIKINGDB_ACCESS_KEY is None:
|
||||
raise ValueError("VIKINGDB_ACCESS_KEY should not be None")
|
||||
if dify_config.VIKINGDB_SECRET_KEY is None:
|
||||
raise ValueError("VIKINGDB_SECRET_KEY should not be None")
|
||||
if dify_config.VIKINGDB_HOST is None:
|
||||
raise ValueError("VIKINGDB_HOST should not be None")
|
||||
if dify_config.VIKINGDB_REGION is None:
|
||||
raise ValueError("VIKINGDB_REGION should not be None")
|
||||
if dify_config.VIKINGDB_SCHEME is None:
|
||||
raise ValueError("VIKINGDB_SCHEME should not be None")
|
||||
return VikingDBVector(
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=VikingDBConfig(
|
||||
access_key=dify_config.VIKINGDB_ACCESS_KEY,
|
||||
secret_key=dify_config.VIKINGDB_SECRET_KEY,
|
||||
host=dify_config.VIKINGDB_HOST,
|
||||
region=dify_config.VIKINGDB_REGION,
|
||||
scheme=dify_config.VIKINGDB_SCHEME,
|
||||
connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT,
|
||||
socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT,
|
||||
),
|
||||
)
|
73
api/poetry.lock
generated
73
api/poetry.lock
generated
@ -2038,6 +2038,17 @@ packaging = ">=17.0"
|
||||
pandas = ">=0.24.2"
|
||||
pyarrow = ">=3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
version = "5.1.1"
|
||||
description = "Decorators for Humans"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
|
||||
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "defusedxml"
|
||||
version = "0.7.1"
|
||||
@ -3027,6 +3038,20 @@ files = [
|
||||
docs = ["sphinx (>=4)", "sphinx-rtd-theme (>=1)"]
|
||||
tests = ["cython", "hypothesis", "mpmath", "pytest", "setuptools"]
|
||||
|
||||
[[package]]
|
||||
name = "google"
|
||||
version = "3.0.0"
|
||||
description = "Python bindings to the Google search engine."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "google-3.0.0-py2.py3-none-any.whl", hash = "sha256:889cf695f84e4ae2c55fbc0cfdaf4c1e729417fa52ab1db0485202ba173e4935"},
|
||||
{file = "google-3.0.0.tar.gz", hash = "sha256:143530122ee5130509ad5e989f0512f7cb218b2d4eddbafbad40fd10e8d8ccbe"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
beautifulsoup4 = "*"
|
||||
|
||||
[[package]]
|
||||
name = "google-ai-generativelanguage"
|
||||
version = "0.6.9"
|
||||
@ -6670,6 +6695,17 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "py"
|
||||
version = "1.11.0"
|
||||
description = "library with cross-python path, ini-parsing, io, code, log facilities"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
files = [
|
||||
{file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
|
||||
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "py-cpuinfo"
|
||||
version = "9.0.0"
|
||||
@ -8012,6 +8048,21 @@ files = [
|
||||
[package.dependencies]
|
||||
requests = "2.31.0"
|
||||
|
||||
[[package]]
|
||||
name = "retry"
|
||||
version = "0.9.2"
|
||||
description = "Easy to use retry decorator."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"},
|
||||
{file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
decorator = ">=3.4.2"
|
||||
py = ">=1.4.26,<2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "13.9.2"
|
||||
@ -9829,6 +9880,26 @@ files = [
|
||||
{file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "volcengine-compat"
|
||||
version = "1.0.156"
|
||||
description = "Be Compatible with the Volcengine SDK for Python, The version of package dependencies has been modified. like pycryptodome, pytz."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "volcengine_compat-1.0.156-py3-none-any.whl", hash = "sha256:4abc149a7601ebad8fa2d28fab50c7945145cf74daecb71bca797b0bdc82c5a5"},
|
||||
{file = "volcengine_compat-1.0.156.tar.gz", hash = "sha256:e357d096828e31a202dc6047bbc5bf6fff3f54a98cd35a99ab5f965ea741a267"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
google = ">=3.0.0"
|
||||
protobuf = ">=3.18.3"
|
||||
pycryptodome = ">=3.9.9"
|
||||
pytz = ">=2020.5"
|
||||
requests = ">=2.25.1"
|
||||
retry = ">=0.9.2"
|
||||
six = ">=1.0"
|
||||
|
||||
[[package]]
|
||||
name = "volcengine-python-sdk"
|
||||
version = "1.0.103"
|
||||
@ -10636,4 +10707,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774"
|
||||
content-hash = "edb5e3b0d50e84a239224cc77f3f615fdbdd6b504bce5b1075b29363f3054957"
|
||||
|
@ -246,6 +246,7 @@ pymochow = "1.3.1"
|
||||
qdrant-client = "1.7.3"
|
||||
tcvectordb = "1.3.2"
|
||||
tidb-vector = "0.0.9"
|
||||
volcengine-compat = "~1.0.156"
|
||||
weaviate-client = "~3.21.0"
|
||||
|
||||
############################################################
|
||||
|
215
api/tests/integration_tests/vdb/__mock/vikingdb.py
Normal file
215
api/tests/integration_tests/vdb/__mock/vikingdb.py
Normal file
@ -0,0 +1,215 @@
|
||||
import os
|
||||
from typing import Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from volcengine.viking_db import (
|
||||
Collection,
|
||||
Data,
|
||||
DistanceType,
|
||||
Field,
|
||||
FieldType,
|
||||
Index,
|
||||
IndexType,
|
||||
QuantType,
|
||||
VectorIndexParams,
|
||||
VikingDBService,
|
||||
)
|
||||
|
||||
from core.rag.datasource.vdb.field import Field as vdb_Field
|
||||
|
||||
|
||||
class MockVikingDBClass:
|
||||
def __init__(
|
||||
self,
|
||||
host="api-vikingdb.volces.com",
|
||||
region="cn-north-1",
|
||||
ak="",
|
||||
sk="",
|
||||
scheme="http",
|
||||
connection_timeout=30,
|
||||
socket_timeout=30,
|
||||
proxy=None,
|
||||
):
|
||||
self._viking_db_service = MagicMock()
|
||||
self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
|
||||
|
||||
def get_collection(self, collection_name) -> Collection:
|
||||
return Collection(
|
||||
collection_name=collection_name,
|
||||
description="Collection For Dify",
|
||||
viking_db_service=self._viking_db_service,
|
||||
primary_key=vdb_Field.PRIMARY_KEY.value,
|
||||
fields=[
|
||||
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
|
||||
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
|
||||
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
|
||||
],
|
||||
indexes=[
|
||||
Index(
|
||||
collection_name=collection_name,
|
||||
index_name=f"{collection_name}_idx",
|
||||
vector_index=VectorIndexParams(
|
||||
distance=DistanceType.L2,
|
||||
index_type=IndexType.HNSW,
|
||||
quant=QuantType.Float,
|
||||
),
|
||||
scalar_index=None,
|
||||
stat=None,
|
||||
viking_db_service=self._viking_db_service,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def drop_collection(self, collection_name):
|
||||
assert collection_name != ""
|
||||
|
||||
def create_collection(self, collection_name, fields, description="") -> Collection:
|
||||
return Collection(
|
||||
collection_name=collection_name,
|
||||
description=description,
|
||||
primary_key=vdb_Field.PRIMARY_KEY.value,
|
||||
viking_db_service=self._viking_db_service,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
def get_index(self, collection_name, index_name) -> Index:
|
||||
return Index(
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
viking_db_service=self._viking_db_service,
|
||||
stat=None,
|
||||
scalar_index=None,
|
||||
vector_index=VectorIndexParams(
|
||||
distance=DistanceType.L2,
|
||||
index_type=IndexType.HNSW,
|
||||
quant=QuantType.Float,
|
||||
),
|
||||
)
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
collection_name,
|
||||
index_name,
|
||||
vector_index=None,
|
||||
cpu_quota=2,
|
||||
description="",
|
||||
partition_by="",
|
||||
scalar_index=None,
|
||||
shard_count=None,
|
||||
shard_policy=None,
|
||||
):
|
||||
return Index(
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
vector_index=vector_index,
|
||||
cpu_quota=cpu_quota,
|
||||
description=description,
|
||||
partition_by=partition_by,
|
||||
scalar_index=scalar_index,
|
||||
shard_count=shard_count,
|
||||
shard_policy=shard_policy,
|
||||
viking_db_service=self._viking_db_service,
|
||||
stat=None,
|
||||
)
|
||||
|
||||
def drop_index(self, collection_name, index_name):
|
||||
assert collection_name != ""
|
||||
assert index_name != ""
|
||||
|
||||
def upsert_data(self, data: Union[Data, list[Data]]):
|
||||
assert data is not None
|
||||
|
||||
def fetch_data(self, id: Union[str, list[str], int, list[int]]):
|
||||
return Data(
|
||||
fields={
|
||||
vdb_Field.GROUP_KEY.value: "test_group",
|
||||
vdb_Field.METADATA_KEY.value: "{}",
|
||||
vdb_Field.CONTENT_KEY.value: "content",
|
||||
vdb_Field.PRIMARY_KEY.value: id,
|
||||
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
|
||||
},
|
||||
id=id,
|
||||
)
|
||||
|
||||
def delete_data(self, id: Union[str, list[str], int, list[int]]):
|
||||
assert id is not None
|
||||
|
||||
def search_by_vector(
|
||||
self,
|
||||
vector,
|
||||
sparse_vectors=None,
|
||||
filter=None,
|
||||
limit=10,
|
||||
output_fields=None,
|
||||
partition="default",
|
||||
dense_weight=None,
|
||||
) -> list[Data]:
|
||||
return [
|
||||
Data(
|
||||
fields={
|
||||
vdb_Field.GROUP_KEY.value: "test_group",
|
||||
vdb_Field.METADATA_KEY.value: '\
|
||||
{"source": "/var/folders/ml/xxx/xxx.txt", \
|
||||
"document_id": "test_document_id", \
|
||||
"dataset_id": "test_dataset_id", \
|
||||
"doc_id": "test_id", \
|
||||
"doc_hash": "test_hash"}',
|
||||
vdb_Field.CONTENT_KEY.value: "content",
|
||||
vdb_Field.PRIMARY_KEY.value: "test_id",
|
||||
vdb_Field.VECTOR.value: vector,
|
||||
},
|
||||
id="test_id",
|
||||
score=0.10,
|
||||
)
|
||||
]
|
||||
|
||||
def search(
|
||||
self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
|
||||
) -> list[Data]:
|
||||
return [
|
||||
Data(
|
||||
fields={
|
||||
vdb_Field.GROUP_KEY.value: "test_group",
|
||||
vdb_Field.METADATA_KEY.value: '\
|
||||
{"source": "/var/folders/ml/xxx/xxx.txt", \
|
||||
"document_id": "test_document_id", \
|
||||
"dataset_id": "test_dataset_id", \
|
||||
"doc_id": "test_id", \
|
||||
"doc_hash": "test_hash"}',
|
||||
vdb_Field.CONTENT_KEY.value: "content",
|
||||
vdb_Field.PRIMARY_KEY.value: "test_id",
|
||||
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
|
||||
},
|
||||
id="test_id",
|
||||
score=0.10,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__)
|
||||
monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection)
|
||||
monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection)
|
||||
monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection)
|
||||
monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index)
|
||||
monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index)
|
||||
monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index)
|
||||
monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data)
|
||||
monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data)
|
||||
monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data)
|
||||
monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector)
|
||||
monkeypatch.setattr(Index, "search", MockVikingDBClass.search)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
37
api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py
Normal file
37
api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py
Normal file
@ -0,0 +1,37 @@
|
||||
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector
|
||||
from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class VikingDBVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = VikingDBVector(
|
||||
"test_collection",
|
||||
"test_group",
|
||||
config=VikingDBConfig(
|
||||
access_key="test_access_key",
|
||||
host="test_host",
|
||||
region="test_region",
|
||||
scheme="test_scheme",
|
||||
secret_key="test_secret_key",
|
||||
connection_timeout=30,
|
||||
socket_timeout=30,
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 1
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id")
|
||||
assert len(ids) > 0
|
||||
|
||||
|
||||
def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock):
|
||||
VikingDBVectorTest().run_all_tests()
|
@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate \
|
||||
api/tests/integration_tests/vdb/elasticsearch
|
||||
api/tests/integration_tests/vdb/elasticsearch \
|
||||
api/tests/integration_tests/vdb/vikingdb
|
||||
|
@ -173,6 +173,11 @@ x-shared-env: &shared-api-worker-env
|
||||
BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
|
||||
BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
|
||||
BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
|
||||
VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-dify}
|
||||
VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-dify}
|
||||
VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai}
|
||||
VIKINGDB_HOST: ${VIKINGDB_HOST:-api-vikingdb.xxx.volces.com}
|
||||
VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http}
|
||||
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
|
||||
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
|
||||
ETL_TYPE: ${ETL_TYPE:-dify}
|
||||
|
Loading…
x
Reference in New Issue
Block a user