test: improve vector store tests (#3855)

This commit is contained in:
Bowen Liang 2024-04-26 19:18:42 +08:00 committed by GitHub
parent 4d66a86579
commit 045827043d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 58 additions and 74 deletions

View File

@ -37,27 +37,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Weaviate
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.middleware.yaml
services: weaviate
- name: Set up Qdrant
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.qdrant.yaml
services: qdrant
- name: Set up Milvus
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.milvus.yaml
services: |
etcd
minio
milvus-standalone
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
@ -82,5 +61,19 @@ jobs:
- name: Run Workflow - name: Run Workflow
run: dev/pytest/pytest_workflow.sh run: dev/pytest/pytest_workflow.sh
- name: Run Vector Stores - name: Set up Vector Stores (Weaviate, Qdrant and Milvus)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
docker/docker-compose.qdrant.yaml
docker/docker-compose.milvus.yaml
services: |
weaviate
qdrant
etcd
minio
milvus-standalone
- name: Test Vector Stores
run: dev/pytest/pytest_vdb.sh run: dev/pytest/pytest_vdb.sh

View File

@ -124,7 +124,7 @@ class MilvusVector(BaseVector):
if ids: if ids:
self._client.delete(collection_name=self._collection_name, pks=ids) self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
alias = uuid4().hex alias = uuid4().hex
if self._client_config.secure: if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
@ -136,7 +136,7 @@ class MilvusVector(BaseVector):
if utility.has_collection(self._collection_name, using=alias): if utility.has_collection(self._collection_name, using=alias):
result = self._client.query(collection_name=self._collection_name, result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}', filter=f'metadata["doc_id"] in {ids}',
output_fields=["id"]) output_fields=["id"])
if result: if result:
ids = [item["id"] for item in result] ids = [item["id"] for item in result]

View File

@ -199,10 +199,10 @@ class RelytVector(BaseVector):
if ids: if ids:
self.delete_by_uuids(ids) self.delete_by_uuids(ids)
def delete_by_ids(self, doc_ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
with Session(self.client) as session: with Session(self.client) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids) ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text( select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
) )

View File

@ -1,7 +1,7 @@
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from tests.integration_tests.vdb.test_vector_store import ( from tests.integration_tests.vdb.test_vector_store import (
AbstractTestVector, AbstractTestVector,
get_sample_text, get_example_text,
setup_mock_redis, setup_mock_redis,
) )
@ -21,15 +21,15 @@ class TestMilvusVector(AbstractTestVector):
def search_by_full_text(self): def search_by_full_text(self):
# milvus does not support full text searching yet in < 2.3.x # milvus does not support full text searching yet in < 2.3.x
hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0 assert len(hits_by_full_text) == 0
def delete_document_by_id(self): def delete_by_document_id(self):
self.vector.delete_by_document_id(self.dataset_id) self.vector.delete_by_document_id(document_id=self.example_doc_id)
def get_ids_by_metadata_field(self): def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field('document_id', self.dataset_id) ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
assert len(ids) >= 1 assert len(ids) == 1
def test_milvus_vector(setup_mock_redis): def test_milvus_vector(setup_mock_redis):

View File

@ -1,3 +1,4 @@
import random
import uuid import uuid
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -8,26 +9,18 @@ from extensions import ext_redis
from models.dataset import Dataset from models.dataset import Dataset
def get_sample_text() -> str: def get_example_text() -> str:
return 'test_text' return 'test_text'
def get_sample_embedding() -> list[float]: def get_example_document(doc_id: str) -> Document:
return [1.1, 2.2, 3.3]
def get_sample_query_vector() -> list[float]:
return get_sample_embedding()
def get_sample_document(sample_dataset_id: str) -> Document:
doc = Document( doc = Document(
page_content=get_sample_text(), page_content=get_example_text(),
metadata={ metadata={
"doc_id": sample_dataset_id, "doc_id": doc_id,
"doc_hash": sample_dataset_id, "doc_hash": doc_id,
"document_id": sample_dataset_id, "document_id": doc_id,
"dataset_id": sample_dataset_id, "dataset_id": doc_id,
} }
) )
return doc return doc
@ -53,49 +46,48 @@ class AbstractTestVector:
self.vector = None self.vector = None
self.dataset_id = str(uuid.uuid4()) self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
self.example_doc_id = str(uuid.uuid4())
self.example_embedding = [1.001 * i for i in range(128)]
def create_vector(self) -> None: def create_vector(self) -> None:
self.vector.create( self.vector.create(
texts=[get_sample_document(self.dataset_id)], texts=[get_example_document(doc_id=self.example_doc_id)],
embeddings=[get_sample_embedding()], embeddings=[self.example_embedding],
) )
def search_by_vector(self): def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector()) hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) >= 1 assert len(hits_by_vector) == 1
assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id
def search_by_full_text(self): def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) >= 1 assert len(hits_by_full_text) == 1
assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id
def delete_vector(self): def delete_vector(self):
self.vector.delete() self.vector.delete()
def delete_by_ids(self): def delete_by_ids(self, ids: list[str]):
self.vector.delete_by_ids([self.dataset_id]) self.vector.delete_by_ids(ids=ids)
def add_texts(self): def add_texts(self) -> list[str]:
self.vector.add_texts( batch_size = 100
documents=[ documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
get_sample_document(str(uuid.uuid4())), embeddings = [self.example_embedding] * batch_size
get_sample_document(str(uuid.uuid4())), self.vector.add_texts(documents=documents, embeddings=embeddings)
], return [doc.metadata['doc_id'] for doc in documents]
embeddings=[
get_sample_embedding(),
get_sample_embedding(),
],
)
def text_exists(self): def text_exists(self):
self.vector.text_exists(self.dataset_id) assert self.vector.text_exists(self.example_doc_id)
def delete_document_by_id(self): def delete_by_document_id(self):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
self.vector.delete_by_document_id(self.dataset_id) self.vector.delete_by_document_id(document_id=self.example_doc_id)
def get_ids_by_metadata_field(self): def get_ids_by_metadata_field(self):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
self.vector.get_ids_by_metadata_field('key', 'value') self.vector.get_ids_by_metadata_field(key='key', value='value')
def run_all_tests(self): def run_all_tests(self):
self.create_vector() self.create_vector()
@ -103,7 +95,7 @@ class AbstractTestVector:
self.search_by_full_text() self.search_by_full_text()
self.text_exists() self.text_exists()
self.get_ids_by_metadata_field() self.get_ids_by_metadata_field()
self.add_texts() self.delete_by_document_id()
self.delete_document_by_id() added_doc_ids = self.add_texts()
self.delete_by_ids() self.delete_by_ids(added_doc_ids)
self.delete_vector() self.delete_vector()

View File

@ -1,5 +1,4 @@
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import ( from tests.integration_tests.vdb.test_vector_store import (
AbstractTestVector, AbstractTestVector,
setup_mock_redis, setup_mock_redis,