diff --git a/api/config.py b/api/config.py index d13882d141..efba84641c 100644 --- a/api/config.py +++ b/api/config.py @@ -67,6 +67,7 @@ DEFAULTS = { 'CODE_EXECUTION_ENDPOINT': '', 'CODE_EXECUTION_API_KEY': '', 'TOOL_ICON_CACHE_MAX_AGE': 3600, + 'MILVUS_DATABASE': 'default', 'KEYWORD_DATA_SOURCE_TYPE': 'database', } @@ -212,6 +213,7 @@ class Config: self.MILVUS_USER = get_env('MILVUS_USER') self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD') self.MILVUS_SECURE = get_env('MILVUS_SECURE') + self.MILVUS_DATABASE = get_env('MILVUS_DATABASE') # weaviate settings self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index dcb37ccbe6..cff490176c 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -20,16 +20,17 @@ class MilvusConfig(BaseModel): password: str secure: bool = False batch_size: int = 100 + database: str = "default" @root_validator() def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values.get('host'): raise ValueError("config MILVUS_HOST is required") - if not values['port']: + if not values.get('port'): raise ValueError("config MILVUS_PORT is required") - if not values['user']: + if not values.get('user'): raise ValueError("config MILVUS_USER is required") - if not values['password']: + if not values.get('password'): raise ValueError("config MILVUS_PASSWORD is required") return values @@ -39,7 +40,8 @@ class MilvusConfig(BaseModel): 'port': self.port, 'user': self.user, 'password': self.password, - 'secure': self.secure + 'secure': self.secure, + 'db_name': self.database, } @@ -192,7 +194,7 @@ class MilvusVector(BaseVector): else: uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) connections.connect(alias=alias, uri=uri, user=self._client_config.user, - password=self._client_config.password) + password=self._client_config.password, db_name=self._client_config.database) if not utility.has_collection(self._collection_name, using=alias): from pymilvus import CollectionSchema, DataType, FieldSchema from pymilvus.orm.types import infer_dtype_bydata diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 71fc07967c..057f06d297 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -110,6 +110,7 @@ class Vector: user=config.get('MILVUS_USER'), password=config.get('MILVUS_PASSWORD'), secure=config.get('MILVUS_SECURE'), + database=config.get('MILVUS_DATABASE'), ) ) else: diff --git a/api/tests/unittests/test_model.py b/api/tests/unittests/test_model.py new file mode 100644 index 0000000000..73257dd338 --- /dev/null +++ b/api/tests/unittests/test_model.py @@ -0,0 +1,24 @@ +import pytest +from pydantic.error_wrappers import ValidationError + +from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig + + +def test_default_value(): + valid_config = { + 'host': 'localhost', + 'port': 19530, + 'user': 'root', + 'password': 'Milvus' + } + + for key in valid_config: + config = valid_config.copy() + del config[key] + with pytest.raises(ValidationError) as e: + MilvusConfig(**config) + assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required' + + config = MilvusConfig(**valid_config) + assert config.secure is False + assert config.database == 'default'