feat: support setting database used in Milvus (#3003)

This commit is contained in:
Leo Q 2024-04-09 15:39:36 +08:00 committed by GitHub
parent a2c068d949
commit 9c01bcb3e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 6 deletions

View File

@ -67,6 +67,7 @@ DEFAULTS = {
'CODE_EXECUTION_ENDPOINT': '', 'CODE_EXECUTION_ENDPOINT': '',
'CODE_EXECUTION_API_KEY': '', 'CODE_EXECUTION_API_KEY': '',
'TOOL_ICON_CACHE_MAX_AGE': 3600, 'TOOL_ICON_CACHE_MAX_AGE': 3600,
'MILVUS_DATABASE': 'default',
'KEYWORD_DATA_SOURCE_TYPE': 'database', 'KEYWORD_DATA_SOURCE_TYPE': 'database',
} }
@ -212,6 +213,7 @@ class Config:
self.MILVUS_USER = get_env('MILVUS_USER') self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD') self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE') self.MILVUS_SECURE = get_env('MILVUS_SECURE')
self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')
# weaviate settings # weaviate settings
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')

View File

@ -20,16 +20,17 @@ class MilvusConfig(BaseModel):
password: str password: str
secure: bool = False secure: bool = False
batch_size: int = 100 batch_size: int = 100
database: str = "default"
@root_validator() @root_validator()
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values['host']: if not values.get('host'):
raise ValueError("config MILVUS_HOST is required") raise ValueError("config MILVUS_HOST is required")
if not values['port']: if not values.get('port'):
raise ValueError("config MILVUS_PORT is required") raise ValueError("config MILVUS_PORT is required")
if not values['user']: if not values.get('user'):
raise ValueError("config MILVUS_USER is required") raise ValueError("config MILVUS_USER is required")
if not values['password']: if not values.get('password'):
raise ValueError("config MILVUS_PASSWORD is required") raise ValueError("config MILVUS_PASSWORD is required")
return values return values
@ -39,7 +40,8 @@ class MilvusConfig(BaseModel):
'port': self.port, 'port': self.port,
'user': self.user, 'user': self.user,
'password': self.password, 'password': self.password,
'secure': self.secure 'secure': self.secure,
'db_name': self.database,
} }
@ -192,7 +194,7 @@ class MilvusVector(BaseVector):
else: else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, connections.connect(alias=alias, uri=uri, user=self._client_config.user,
password=self._client_config.password) password=self._client_config.password, db_name=self._client_config.database)
if not utility.has_collection(self._collection_name, using=alias): if not utility.has_collection(self._collection_name, using=alias):
from pymilvus import CollectionSchema, DataType, FieldSchema from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata from pymilvus.orm.types import infer_dtype_bydata

View File

@ -110,6 +110,7 @@ class Vector:
user=config.get('MILVUS_USER'), user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'), password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'), secure=config.get('MILVUS_SECURE'),
database=config.get('MILVUS_DATABASE'),
) )
) )
else: else:

View File

@ -0,0 +1,24 @@
import pytest
from pydantic.error_wrappers import ValidationError
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
def test_default_value():
valid_config = {
'host': 'localhost',
'port': 19530,
'user': 'root',
'password': 'Milvus'
}
for key in valid_config:
config = valid_config.copy()
del config[key]
with pytest.raises(ValidationError) as e:
MilvusConfig(**config)
assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required'
config = MilvusConfig(**valid_config)
assert config.secure is False
assert config.database == 'default'