mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-05-31 10:35:04 +08:00
194 lines
7.0 KiB
Python
194 lines
7.0 KiB
Python
import os
|
|
from typing import Optional, Union
|
|
|
|
import pytest
|
|
from _pytest.monkeypatch import MonkeyPatch
|
|
from requests.adapters import HTTPAdapter
|
|
from tcvectordb import RPCVectorDBClient # type: ignore
|
|
from tcvectordb.model import enum
|
|
from tcvectordb.model.collection import FilterIndexConfig
|
|
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
|
|
from tcvectordb.model.enum import ReadConsistency # type: ignore
|
|
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
|
|
from tcvectordb.rpc.model.collection import RPCCollection
|
|
from tcvectordb.rpc.model.database import RPCDatabase
|
|
from xinference_client.types import Embedding # type: ignore
|
|
|
|
|
|
class MockTcvectordbClass:
    """Canned replacements for the ``RPCVectorDBClient`` methods used in tests.

    The ``setup_tcvectordb_mock`` fixture in this module patches these functions
    directly onto ``RPCVectorDBClient``. At call time ``self`` is therefore a
    client instance, not an instance of this class, so the methods may only rely
    on the attributes assigned in ``mock_vector_db_client`` and must not call
    helpers defined on this class through ``self``.
    """

    def mock_vector_db_client(
        self,
        url: str,
        username="",
        key="",
        read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
        timeout=10,
        adapter: Optional[HTTPAdapter] = None,
        pool_size: int = 2,
        proxies: Optional[dict] = None,
        password: Optional[str] = None,
        **kwargs,
    ):
        """Replacement for ``RPCVectorDBClient.__init__``.

        Only ``read_consistency`` is retained; every other argument is accepted
        for signature compatibility and ignored.
        """
        self._conn = None
        self._read_consistency = read_consistency

    def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
        """Return a database handle named ``database_name``.

        The handle carries the client's read consistency. ``timeout`` is ignored.
        """
        # Use the requested name (not a hard-coded one) so this agrees with
        # describe_collection, which already reports the caller's database.
        return RPCDatabase(
            name=database_name,
            read_consistency=self._read_consistency,
        )

    def exists_collection(self, database_name: str, collection_name: str) -> bool:
        """Report every collection as existing, whatever the arguments."""
        return True

    def describe_collection(
        self, database_name: str, collection_name: str, timeout: Optional[float] = None
    ) -> RPCCollection:
        """Return a collection handle with a fixed index layout.

        The layout is a primary-key filter index on ``id``, an HNSW vector index
        on ``vector`` and filter indexes on ``text`` and ``metadata``.
        ``timeout`` is ignored.
        """
        index = Index(
            FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
            VectorIndex(
                "vector",
                128,  # presumably the vector dimension -- confirm against the tcvectordb API
                enum.IndexType.HNSW,
                enum.MetricType.IP,
                HNSWParams(m=16, efconstruction=200),
            ),
            FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
            FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
        )
        return RPCCollection(
            RPCDatabase(
                name=database_name,
                read_consistency=self._read_consistency,
            ),
            collection_name,
            index=index,
        )

    def create_collection(
        self,
        database_name: str,
        collection_name: str,
        shard: int,
        replicas: int,
        description: Optional[str] = None,
        index: Optional[Index] = None,
        embedding: Optional[Embedding] = None,
        timeout: Optional[float] = None,
        ttl_config: Optional[dict] = None,
        filter_index_config: Optional[FilterIndexConfig] = None,
        indexes: Optional[list[IndexField]] = None,
    ) -> RPCCollection:
        """Return a collection handle built from the given arguments.

        The arguments are forwarded unchanged to ``RPCCollection``; the parent
        database handle is named ``database_name``.
        """
        # Same rationale as create_database_if_not_exists: reflect the database
        # the caller asked for instead of a fixed name.
        return RPCCollection(
            RPCDatabase(
                name=database_name,
                read_consistency=self._read_consistency,
            ),
            collection_name,
            shard,
            replicas,
            description,
            index,
            embedding=embedding,
            read_consistency=self._read_consistency,
            timeout=timeout,
            ttl_config=ttl_config,
            filter_index_config=filter_index_config,
            indexes=indexes,
        )

    def collection_upsert(
        self,
        database_name: str,
        collection_name: str,
        documents: list[Union[Document, dict]],
        timeout: Optional[float] = None,
        build_index: bool = True,
        **kwargs,
    ):
        """Accept any documents and return a fixed success payload."""
        return {"code": 0, "msg": "operation success"}

    def collection_search(
        self,
        database_name: str,
        collection_name: str,
        vectors: list[list[float]],
        filter: Optional[Filter] = None,
        params=None,
        retrieve_vector: bool = False,
        limit: int = 10,
        output_fields: Optional[list[str]] = None,
        timeout: Optional[float] = None,
    ) -> list[list[dict]]:
        """Return one result group holding a single fixed hit; arguments are ignored."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_hybrid_search(
        self,
        database_name: str,
        collection_name: str,
        ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
        match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
        filter: Optional[Union[Filter, str]] = None,
        rerank: Optional[Rerank] = None,
        retrieve_vector: Optional[bool] = None,
        output_fields: Optional[list[str]] = None,
        limit: Optional[int] = None,
        timeout: Optional[float] = None,
        return_pd_object=False,
        **kwargs,
    ) -> list[list[dict]]:
        """Return the same fixed single hit as ``collection_search``; arguments are ignored."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_query(
        self,
        database_name: str,
        collection_name: str,
        document_ids: Optional[list] = None,
        retrieve_vector: bool = False,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        filter: Optional[Filter] = None,
        output_fields: Optional[list[str]] = None,
        timeout: Optional[float] = None,
    ) -> list[dict]:
        """Return a flat list with one fixed document; arguments are ignored.

        Unlike the search mocks, ``metadata`` here is a JSON-encoded string
        rather than a dict.
        """
        return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]

    def collection_delete(
        self,
        database_name: str,
        collection_name: str,
        document_ids: Optional[list[str]] = None,
        filter: Optional[Filter] = None,
        timeout: Optional[float] = None,
    ):
        """Pretend to delete documents and return a fixed success payload."""
        return {"code": 0, "msg": "operation success"}

    def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict:
        """Pretend to drop the collection and return a fixed success payload."""
        return {"code": 0, "msg": "operation success"}
|
|
|
|
|
|
# Mocking is opt-in: enabled only when the MOCK_SWITCH environment variable
# equals "true" (compared case-insensitively); unset counts as "false".
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
|
|
|
|
|
|
@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Swap ``RPCVectorDBClient`` methods for their mocks while a test runs.

    Nothing is patched unless the module-level ``MOCK`` flag is true. When it
    is, the patches are installed before the test body and undone after it.
    """
    if MOCK:
        # (attribute on RPCVectorDBClient, replacement function), applied in order.
        replacements = (
            ("__init__", MockTcvectordbClass.mock_vector_db_client),
            ("create_database_if_not_exists", MockTcvectordbClass.create_database_if_not_exists),
            ("exists_collection", MockTcvectordbClass.exists_collection),
            ("create_collection", MockTcvectordbClass.create_collection),
            ("describe_collection", MockTcvectordbClass.describe_collection),
            ("upsert", MockTcvectordbClass.collection_upsert),
            ("search", MockTcvectordbClass.collection_search),
            ("hybrid_search", MockTcvectordbClass.collection_hybrid_search),
            ("query", MockTcvectordbClass.collection_query),
            ("delete", MockTcvectordbClass.collection_delete),
            ("drop_collection", MockTcvectordbClass.drop_collection),
        )
        for attribute, stand_in in replacements:
            monkeypatch.setattr(RPCVectorDBClient, attribute, stand_in)

    yield

    if MOCK:
        # Explicit restore of everything patched above.
        monkeypatch.undo()
|