diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 6dde4d71c6..6bdc0e726d 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -37,27 +37,6 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: Set up Weaviate - uses: hoverkraft-tech/compose-action@v2.0.0 - with: - compose-file: docker/docker-compose.middleware.yaml - services: weaviate - - - name: Set up Qdrant - uses: hoverkraft-tech/compose-action@v2.0.0 - with: - compose-file: docker/docker-compose.qdrant.yaml - services: qdrant - - - name: Set up Milvus - uses: hoverkraft-tech/compose-action@v2.0.0 - with: - compose-file: docker/docker-compose.milvus.yaml - services: | - etcd - minio - milvus-standalone - - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -82,5 +61,19 @@ jobs: - name: Run Workflow run: dev/pytest/pytest_workflow.sh - - name: Run Vector Stores + - name: Set up Vector Stores (Weaviate, Qdrant and Milvus) + uses: hoverkraft-tech/compose-action@v2.0.0 + with: + compose-file: | + docker/docker-compose.middleware.yaml + docker/docker-compose.qdrant.yaml + docker/docker-compose.milvus.yaml + services: | + weaviate + qdrant + etcd + minio + milvus-standalone + + - name: Test Vector Stores run: dev/pytest/pytest_vdb.sh diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 63cf502149..c90fe3b188 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -124,7 +124,7 @@ class MilvusVector(BaseVector): if ids: self._client.delete(collection_name=self._collection_name, pks=ids) - def delete_by_ids(self, doc_ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]) -> None: alias = uuid4().hex if self._client_config.secure: uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) @@ -136,7 +136,7 @@ class MilvusVector(BaseVector): if utility.has_collection(self._collection_name, using=alias): result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] in {doc_ids}', + filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]) if result: ids = [item["id"] for item in result] diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index b9e87d0c40..c7d3575352 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -199,10 +199,10 @@ class RelytVector(BaseVector): if ids: self.delete_by_uuids(ids) - def delete_by_ids(self, doc_ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]) -> None: with Session(self.client) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids) + ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ ) diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index 09b1866b79..00c4140003 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -1,7 +1,7 @@ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector from tests.integration_tests.vdb.test_vector_store import ( AbstractTestVector, - get_sample_text, + get_example_text, setup_mock_redis, ) @@ -21,15 +21,15 @@ class TestMilvusVector(AbstractTestVector): def search_by_full_text(self): # milvus dos not support full text searching yet in < 2.3.x - hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 - def delete_document_by_id(self): - self.vector.delete_by_document_id(self.dataset_id) + def delete_by_document_id(self): + self.vector.delete_by_document_id(document_id=self.example_doc_id) def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field('document_id', self.dataset_id) - assert len(ids) >= 1 + ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + assert len(ids) == 1 def test_milvus_vector(setup_mock_redis): diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index ccfae8b64b..fd64d445ed 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -1,3 +1,4 @@ +import random import uuid from unittest.mock import MagicMock @@ -8,26 +9,18 @@ from extensions import ext_redis from models.dataset import Dataset -def get_sample_text() -> str: +def get_example_text() -> str: return 'test_text' -def get_sample_embedding() -> list[float]: - return [1.1, 2.2, 3.3] - - -def get_sample_query_vector() -> list[float]: - return get_sample_embedding() - - -def get_sample_document(sample_dataset_id: str) -> Document: +def get_example_document(doc_id: str) -> Document: doc = Document( - page_content=get_sample_text(), + page_content=get_example_text(), metadata={ - "doc_id": sample_dataset_id, - "doc_hash": sample_dataset_id, - "document_id": sample_dataset_id, - "dataset_id": sample_dataset_id, + "doc_id": doc_id, + "doc_hash": doc_id, + "document_id": doc_id, + "dataset_id": doc_id, } ) return doc @@ -53,49 +46,48 @@ class AbstractTestVector: self.vector = None self.dataset_id = str(uuid.uuid4()) self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + self.example_doc_id = str(uuid.uuid4()) + self.example_embedding = [1.001 * i for i in range(128)] def create_vector(self) -> None: self.vector.create( - texts=[get_sample_document(self.dataset_id)], - embeddings=[get_sample_embedding()], + texts=[get_example_document(doc_id=self.example_doc_id)], + embeddings=[self.example_embedding], ) def search_by_vector(self): - hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector()) - assert len(hits_by_vector) >= 1 + hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id def search_by_full_text(self): - hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) - assert len(hits_by_full_text) >= 1 + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 1 + assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id def delete_vector(self): self.vector.delete() - def delete_by_ids(self): - self.vector.delete_by_ids([self.dataset_id]) + def delete_by_ids(self, ids: list[str]): + self.vector.delete_by_ids(ids=ids) - def add_texts(self): - self.vector.add_texts( - documents=[ - get_sample_document(str(uuid.uuid4())), - get_sample_document(str(uuid.uuid4())), - ], - embeddings=[ - get_sample_embedding(), - get_sample_embedding(), - ], - ) + def add_texts(self) -> list[str]: + batch_size = 100 + documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)] + embeddings = [self.example_embedding] * batch_size + self.vector.add_texts(documents=documents, embeddings=embeddings) + return [doc.metadata['doc_id'] for doc in documents] def text_exists(self): - self.vector.text_exists(self.dataset_id) + assert self.vector.text_exists(self.example_doc_id) - def delete_document_by_id(self): + def delete_by_document_id(self): with pytest.raises(NotImplementedError): - self.vector.delete_by_document_id(self.dataset_id) + self.vector.delete_by_document_id(document_id=self.example_doc_id) def get_ids_by_metadata_field(self): with pytest.raises(NotImplementedError): - self.vector.get_ids_by_metadata_field('key', 'value') + self.vector.get_ids_by_metadata_field(key='key', value='value') def run_all_tests(self): self.create_vector() @@ -103,7 +95,7 @@ class AbstractTestVector: self.search_by_full_text() self.text_exists() self.get_ids_by_metadata_field() - self.add_texts() - self.delete_document_by_id() - self.delete_by_ids() + self.delete_by_document_id() + added_doc_ids = self.add_texts() + self.delete_by_ids(added_doc_ids) self.delete_vector() diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py index 625444c3c0..20a2d2be06 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -1,5 +1,4 @@ from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector -from models.dataset import Dataset from tests.integration_tests.vdb.test_vector_store import ( AbstractTestVector, setup_mock_redis,