From 45dd1683fd24d9183b40426888c713413fb1b4ee Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Thu, 25 Apr 2024 22:27:30 +0800 Subject: [PATCH] test: add tests covering all methods of vector store (#3849) --- api/core/rag/datasource/vdb/vector_base.py | 6 ++++ .../vdb/milvus/test_milvus.py | 9 ++++- .../vdb/qdrant/test_qdrant.py | 2 +- .../vdb/test_vector_store.py | 33 ++++++++++++++++++- .../vdb/weaviate/test_weaviate.py | 2 +- 5 files changed, 48 insertions(+), 4 deletions(-) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 69ed4ed51c..c5212aa8f2 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -27,6 +27,12 @@ class BaseVector(ABC): def delete_by_ids(self, ids: list[str]) -> None: raise NotImplementedError + def delete_by_document_id(self, document_id: str): + raise NotImplementedError + + def get_ids_by_metadata_field(self, key: str, value: str): + raise NotImplementedError + @abstractmethod def delete_by_metadata_field(self, key: str, value: str) -> None: raise NotImplementedError diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index e829a8e4d0..09b1866b79 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -24,6 +24,13 @@ class TestMilvusVector(AbstractTestVector): hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) assert len(hits_by_full_text) == 0 + def delete_document_by_id(self): + self.vector.delete_by_document_id(self.dataset_id) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field('document_id', self.dataset_id) + assert len(ids) >= 1 + def test_milvus_vector(setup_mock_redis): - TestMilvusVector().run_all_test() + TestMilvusVector().run_all_tests() diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 0ef3a253b8..ba69206601 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -20,4 +20,4 @@ class TestQdrantVector(AbstractTestVector): def test_qdrant_vector(setup_mock_redis): - TestQdrantVector().run_all_test() + TestQdrantVector().run_all_tests() diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index ab770be228..ccfae8b64b 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -71,8 +71,39 @@ class AbstractTestVector: def delete_vector(self): self.vector.delete() - def run_all_test(self): + def delete_by_ids(self): + self.vector.delete_by_ids([self.dataset_id]) + + def add_texts(self): + self.vector.add_texts( + documents=[ + get_sample_document(str(uuid.uuid4())), + get_sample_document(str(uuid.uuid4())), + ], + embeddings=[ + get_sample_embedding(), + get_sample_embedding(), + ], + ) + + def text_exists(self): + self.vector.text_exists(self.dataset_id) + + def delete_document_by_id(self): + with pytest.raises(NotImplementedError): + self.vector.delete_by_document_id(self.dataset_id) + + def get_ids_by_metadata_field(self): + with pytest.raises(NotImplementedError): + self.vector.get_ids_by_metadata_field('key', 'value') + + def run_all_tests(self): self.create_vector() self.search_by_vector() self.search_by_full_text() + self.text_exists() + self.get_ids_by_metadata_field() + self.add_texts() + self.delete_document_by_id() + self.delete_by_ids() self.delete_vector() diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py index 338e331454..625444c3c0 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -21,4 +21,4 @@ class TestWeaviateVector(AbstractTestVector): def test_weaviate_vector(setup_mock_redis): - TestWeaviateVector().run_all_test() + TestWeaviateVector().run_all_tests()