mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 03:35:51 +08:00
test: improve vector store tests (#3855)
This commit is contained in:
parent
4d66a86579
commit
045827043d
37
.github/workflows/api-tests.yml
vendored
37
.github/workflows/api-tests.yml
vendored
@ -37,27 +37,6 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Weaviate
|
|
||||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
|
||||||
with:
|
|
||||||
compose-file: docker/docker-compose.middleware.yaml
|
|
||||||
services: weaviate
|
|
||||||
|
|
||||||
- name: Set up Qdrant
|
|
||||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
|
||||||
with:
|
|
||||||
compose-file: docker/docker-compose.qdrant.yaml
|
|
||||||
services: qdrant
|
|
||||||
|
|
||||||
- name: Set up Milvus
|
|
||||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
|
||||||
with:
|
|
||||||
compose-file: docker/docker-compose.milvus.yaml
|
|
||||||
services: |
|
|
||||||
etcd
|
|
||||||
minio
|
|
||||||
milvus-standalone
|
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
@ -82,5 +61,19 @@ jobs:
|
|||||||
- name: Run Workflow
|
- name: Run Workflow
|
||||||
run: dev/pytest/pytest_workflow.sh
|
run: dev/pytest/pytest_workflow.sh
|
||||||
|
|
||||||
- name: Run Vector Stores
|
- name: Set up Vector Stores (Weaviate, Qdrant and Milvus)
|
||||||
|
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||||
|
with:
|
||||||
|
compose-file: |
|
||||||
|
docker/docker-compose.middleware.yaml
|
||||||
|
docker/docker-compose.qdrant.yaml
|
||||||
|
docker/docker-compose.milvus.yaml
|
||||||
|
services: |
|
||||||
|
weaviate
|
||||||
|
qdrant
|
||||||
|
etcd
|
||||||
|
minio
|
||||||
|
milvus-standalone
|
||||||
|
|
||||||
|
- name: Test Vector Stores
|
||||||
run: dev/pytest/pytest_vdb.sh
|
run: dev/pytest/pytest_vdb.sh
|
||||||
|
@ -124,7 +124,7 @@ class MilvusVector(BaseVector):
|
|||||||
if ids:
|
if ids:
|
||||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||||
|
|
||||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
alias = uuid4().hex
|
alias = uuid4().hex
|
||||||
if self._client_config.secure:
|
if self._client_config.secure:
|
||||||
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||||
@ -136,7 +136,7 @@ class MilvusVector(BaseVector):
|
|||||||
if utility.has_collection(self._collection_name, using=alias):
|
if utility.has_collection(self._collection_name, using=alias):
|
||||||
|
|
||||||
result = self._client.query(collection_name=self._collection_name,
|
result = self._client.query(collection_name=self._collection_name,
|
||||||
filter=f'metadata["doc_id"] in {doc_ids}',
|
filter=f'metadata["doc_id"] in {ids}',
|
||||||
output_fields=["id"])
|
output_fields=["id"])
|
||||||
if result:
|
if result:
|
||||||
ids = [item["id"] for item in result]
|
ids = [item["id"] for item in result]
|
||||||
|
@ -199,10 +199,10 @@ class RelytVector(BaseVector):
|
|||||||
if ids:
|
if ids:
|
||||||
self.delete_by_uuids(ids)
|
self.delete_by_uuids(ids)
|
||||||
|
|
||||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
|
||||||
with Session(self.client) as session:
|
with Session(self.client) as session:
|
||||||
ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids)
|
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
|
||||||
select_statement = sql_text(
|
select_statement = sql_text(
|
||||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
|
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
||||||
from tests.integration_tests.vdb.test_vector_store import (
|
from tests.integration_tests.vdb.test_vector_store import (
|
||||||
AbstractTestVector,
|
AbstractTestVector,
|
||||||
get_sample_text,
|
get_example_text,
|
||||||
setup_mock_redis,
|
setup_mock_redis,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -21,15 +21,15 @@ class TestMilvusVector(AbstractTestVector):
|
|||||||
|
|
||||||
def search_by_full_text(self):
|
def search_by_full_text(self):
|
||||||
# milvus does not support full text searching yet in < 2.3.x
|
# milvus does not support full text searching yet in < 2.3.x
|
||||||
hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
|
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||||
assert len(hits_by_full_text) == 0
|
assert len(hits_by_full_text) == 0
|
||||||
|
|
||||||
def delete_document_by_id(self):
|
def delete_by_document_id(self):
|
||||||
self.vector.delete_by_document_id(self.dataset_id)
|
self.vector.delete_by_document_id(document_id=self.example_doc_id)
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self):
|
def get_ids_by_metadata_field(self):
|
||||||
ids = self.vector.get_ids_by_metadata_field('document_id', self.dataset_id)
|
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||||
assert len(ids) >= 1
|
assert len(ids) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_milvus_vector(setup_mock_redis):
|
def test_milvus_vector(setup_mock_redis):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
@ -8,26 +9,18 @@ from extensions import ext_redis
|
|||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
def get_sample_text() -> str:
|
def get_example_text() -> str:
|
||||||
return 'test_text'
|
return 'test_text'
|
||||||
|
|
||||||
|
|
||||||
def get_sample_embedding() -> list[float]:
|
def get_example_document(doc_id: str) -> Document:
|
||||||
return [1.1, 2.2, 3.3]
|
|
||||||
|
|
||||||
|
|
||||||
def get_sample_query_vector() -> list[float]:
|
|
||||||
return get_sample_embedding()
|
|
||||||
|
|
||||||
|
|
||||||
def get_sample_document(sample_dataset_id: str) -> Document:
|
|
||||||
doc = Document(
|
doc = Document(
|
||||||
page_content=get_sample_text(),
|
page_content=get_example_text(),
|
||||||
metadata={
|
metadata={
|
||||||
"doc_id": sample_dataset_id,
|
"doc_id": doc_id,
|
||||||
"doc_hash": sample_dataset_id,
|
"doc_hash": doc_id,
|
||||||
"document_id": sample_dataset_id,
|
"document_id": doc_id,
|
||||||
"dataset_id": sample_dataset_id,
|
"dataset_id": doc_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return doc
|
return doc
|
||||||
@ -53,49 +46,48 @@ class AbstractTestVector:
|
|||||||
self.vector = None
|
self.vector = None
|
||||||
self.dataset_id = str(uuid.uuid4())
|
self.dataset_id = str(uuid.uuid4())
|
||||||
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
|
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
|
||||||
|
self.example_doc_id = str(uuid.uuid4())
|
||||||
|
self.example_embedding = [1.001 * i for i in range(128)]
|
||||||
|
|
||||||
def create_vector(self) -> None:
|
def create_vector(self) -> None:
|
||||||
self.vector.create(
|
self.vector.create(
|
||||||
texts=[get_sample_document(self.dataset_id)],
|
texts=[get_example_document(doc_id=self.example_doc_id)],
|
||||||
embeddings=[get_sample_embedding()],
|
embeddings=[self.example_embedding],
|
||||||
)
|
)
|
||||||
|
|
||||||
def search_by_vector(self):
|
def search_by_vector(self):
|
||||||
hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector())
|
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||||
assert len(hits_by_vector) >= 1
|
assert len(hits_by_vector) == 1
|
||||||
|
assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id
|
||||||
|
|
||||||
def search_by_full_text(self):
|
def search_by_full_text(self):
|
||||||
hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
|
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
|
||||||
assert len(hits_by_full_text) >= 1
|
assert len(hits_by_full_text) == 1
|
||||||
|
assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id
|
||||||
|
|
||||||
def delete_vector(self):
|
def delete_vector(self):
|
||||||
self.vector.delete()
|
self.vector.delete()
|
||||||
|
|
||||||
def delete_by_ids(self):
|
def delete_by_ids(self, ids: list[str]):
|
||||||
self.vector.delete_by_ids([self.dataset_id])
|
self.vector.delete_by_ids(ids=ids)
|
||||||
|
|
||||||
def add_texts(self):
|
def add_texts(self) -> list[str]:
|
||||||
self.vector.add_texts(
|
batch_size = 100
|
||||||
documents=[
|
documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
|
||||||
get_sample_document(str(uuid.uuid4())),
|
embeddings = [self.example_embedding] * batch_size
|
||||||
get_sample_document(str(uuid.uuid4())),
|
self.vector.add_texts(documents=documents, embeddings=embeddings)
|
||||||
],
|
return [doc.metadata['doc_id'] for doc in documents]
|
||||||
embeddings=[
|
|
||||||
get_sample_embedding(),
|
|
||||||
get_sample_embedding(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def text_exists(self):
|
def text_exists(self):
|
||||||
self.vector.text_exists(self.dataset_id)
|
assert self.vector.text_exists(self.example_doc_id)
|
||||||
|
|
||||||
def delete_document_by_id(self):
|
def delete_by_document_id(self):
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
self.vector.delete_by_document_id(self.dataset_id)
|
self.vector.delete_by_document_id(document_id=self.example_doc_id)
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self):
|
def get_ids_by_metadata_field(self):
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
self.vector.get_ids_by_metadata_field('key', 'value')
|
self.vector.get_ids_by_metadata_field(key='key', value='value')
|
||||||
|
|
||||||
def run_all_tests(self):
|
def run_all_tests(self):
|
||||||
self.create_vector()
|
self.create_vector()
|
||||||
@ -103,7 +95,7 @@ class AbstractTestVector:
|
|||||||
self.search_by_full_text()
|
self.search_by_full_text()
|
||||||
self.text_exists()
|
self.text_exists()
|
||||||
self.get_ids_by_metadata_field()
|
self.get_ids_by_metadata_field()
|
||||||
self.add_texts()
|
self.delete_by_document_id()
|
||||||
self.delete_document_by_id()
|
added_doc_ids = self.add_texts()
|
||||||
self.delete_by_ids()
|
self.delete_by_ids(added_doc_ids)
|
||||||
self.delete_vector()
|
self.delete_vector()
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||||
from models.dataset import Dataset
|
|
||||||
from tests.integration_tests.vdb.test_vector_store import (
|
from tests.integration_tests.vdb.test_vector_store import (
|
||||||
AbstractTestVector,
|
AbstractTestVector,
|
||||||
setup_mock_redis,
|
setup_mock_redis,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user