feat: Enable baiduvector intergration test (#9369)

This commit is contained in:
ice yao 2024-10-16 09:41:28 +08:00 committed by GitHub
parent da25b91980
commit 568d5c46ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 34 deletions

View File

@ -1,4 +1,5 @@
import os import os
from unittest.mock import MagicMock
import pytest import pytest
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
@ -10,26 +11,31 @@ from pymochow.model.table import Table
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
class AttrDict(dict):
def __getattr__(self, item):
return self.get(item)
class MockBaiduVectorDBClass: class MockBaiduVectorDBClass:
def mock_vector_db_client( def mock_vector_db_client(
self, self,
config=None, config=None,
adapter: HTTPAdapter = None, adapter: HTTPAdapter = None,
): ):
self._conn = None self.conn = MagicMock()
self._config = None self._config = MagicMock()
def list_databases(self, config=None) -> list[Database]: def list_databases(self, config=None) -> list[Database]:
return [ return [
Database( Database(
conn=self._conn, conn=self.conn,
database_name="dify", database_name="dify",
config=self._config, config=self._config,
) )
] ]
def create_database(self, database_name: str, config=None) -> Database: def create_database(self, database_name: str, config=None) -> Database:
return Database(conn=self._conn, database_name=database_name, config=config) return Database(conn=self.conn, database_name=database_name, config=config)
def list_table(self, config=None) -> list[Table]: def list_table(self, config=None) -> list[Table]:
return [] return []
@ -88,16 +94,18 @@ class MockBaiduVectorDBClass:
read_consistency=ReadConsistency.EVENTUAL, read_consistency=ReadConsistency.EVENTUAL,
config=None, config=None,
): ):
return { return AttrDict(
"row": { {
"id": "doc_id_001", "row": {
"vector": [0.23432432, 0.8923744, 0.89238432], "id": primary_key.get("id"),
"text": "text", "vector": [0.23432432, 0.8923744, 0.89238432],
"metadata": {"doc_id": "doc_id_001"}, "text": "text",
}, "metadata": '{"doc_id": "doc_id_001"}',
"code": 0, },
"msg": "Success", "code": 0,
} "msg": "Success",
}
)
def delete(self, primary_key=None, partition_key=None, filter=None, config=None): def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
return {"code": 0, "msg": "Success"} return {"code": 0, "msg": "Success"}
@ -111,22 +119,24 @@ class MockBaiduVectorDBClass:
read_consistency=ReadConsistency.EVENTUAL, read_consistency=ReadConsistency.EVENTUAL,
config=None, config=None,
): ):
return { return AttrDict(
"rows": [ {
{ "rows": [
"row": { {
"id": "doc_id_001", "row": {
"vector": [0.23432432, 0.8923744, 0.89238432], "id": "doc_id_001",
"text": "text", "vector": [0.23432432, 0.8923744, 0.89238432],
"metadata": {"doc_id": "doc_id_001"}, "text": "text",
}, "metadata": '{"doc_id": "doc_id_001"}',
"distance": 0.1, },
"score": 0.5, "distance": 0.1,
} "score": 0.5,
], }
"code": 0, ],
"msg": "Success", "code": 0,
} "msg": "Success",
}
)
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@ -146,6 +156,7 @@ def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index) monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index) monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete) monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
monkeypatch.setattr(Table, "query", MockBaiduVectorDBClass.query)
monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search) monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
yield yield

View File

@ -4,9 +4,6 @@ from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]
class BaiduVectorTest(AbstractVectorTest): class BaiduVectorTest(AbstractVectorTest):
def __init__(self): def __init__(self):

View File

@ -8,4 +8,5 @@ pytest api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate \ api/tests/integration_tests/vdb/weaviate \
api/tests/integration_tests/vdb/elasticsearch \ api/tests/integration_tests/vdb/elasticsearch \
api/tests/integration_tests/vdb/vikingdb api/tests/integration_tests/vdb/vikingdb \
api/tests/integration_tests/vdb/baidu