mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 13:59:04 +08:00
feat: support setting database used in Milvus (#3003)
This commit is contained in:
parent
a2c068d949
commit
9c01bcb3e5
@ -67,6 +67,7 @@ DEFAULTS = {
|
||||
'CODE_EXECUTION_ENDPOINT': '',
|
||||
'CODE_EXECUTION_API_KEY': '',
|
||||
'TOOL_ICON_CACHE_MAX_AGE': 3600,
|
||||
'MILVUS_DATABASE': 'default',
|
||||
'KEYWORD_DATA_SOURCE_TYPE': 'database',
|
||||
}
|
||||
|
||||
@ -212,6 +213,7 @@ class Config:
|
||||
self.MILVUS_USER = get_env('MILVUS_USER')
|
||||
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
|
||||
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
|
||||
self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')
|
||||
|
||||
# weaviate settings
|
||||
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
|
||||
|
@ -20,16 +20,17 @@ class MilvusConfig(BaseModel):
|
||||
password: str
|
||||
secure: bool = False
|
||||
batch_size: int = 100
|
||||
database: str = "default"
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
if not values.get('host'):
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
if not values['port']:
|
||||
if not values.get('port'):
|
||||
raise ValueError("config MILVUS_PORT is required")
|
||||
if not values['user']:
|
||||
if not values.get('user'):
|
||||
raise ValueError("config MILVUS_USER is required")
|
||||
if not values['password']:
|
||||
if not values.get('password'):
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
@ -39,7 +40,8 @@ class MilvusConfig(BaseModel):
|
||||
'port': self.port,
|
||||
'user': self.user,
|
||||
'password': self.password,
|
||||
'secure': self.secure
|
||||
'secure': self.secure,
|
||||
'db_name': self.database,
|
||||
}
|
||||
|
||||
|
||||
@ -192,7 +194,7 @@ class MilvusVector(BaseVector):
|
||||
else:
|
||||
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
connections.connect(alias=alias, uri=uri, user=self._client_config.user,
|
||||
password=self._client_config.password)
|
||||
password=self._client_config.password, db_name=self._client_config.database)
|
||||
if not utility.has_collection(self._collection_name, using=alias):
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
|
@ -110,6 +110,7 @@ class Vector:
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
database=config.get('MILVUS_DATABASE'),
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
24
api/tests/unittests/test_model.py
Normal file
24
api/tests/unittests/test_model.py
Normal file
@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
|
||||
|
||||
|
||||
def test_default_value():
|
||||
valid_config = {
|
||||
'host': 'localhost',
|
||||
'port': 19530,
|
||||
'user': 'root',
|
||||
'password': 'Milvus'
|
||||
}
|
||||
|
||||
for key in valid_config:
|
||||
config = valid_config.copy()
|
||||
del config[key]
|
||||
with pytest.raises(ValidationError) as e:
|
||||
MilvusConfig(**config)
|
||||
assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required'
|
||||
|
||||
config = MilvusConfig(**valid_config)
|
||||
assert config.secure is False
|
||||
assert config.database == 'default'
|
Loading…
x
Reference in New Issue
Block a user