mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 07:39:00 +08:00
feat: support setting database used in Milvus (#3003)
This commit is contained in:
parent
a2c068d949
commit
9c01bcb3e5
@ -67,6 +67,7 @@ DEFAULTS = {
|
|||||||
'CODE_EXECUTION_ENDPOINT': '',
|
'CODE_EXECUTION_ENDPOINT': '',
|
||||||
'CODE_EXECUTION_API_KEY': '',
|
'CODE_EXECUTION_API_KEY': '',
|
||||||
'TOOL_ICON_CACHE_MAX_AGE': 3600,
|
'TOOL_ICON_CACHE_MAX_AGE': 3600,
|
||||||
|
'MILVUS_DATABASE': 'default',
|
||||||
'KEYWORD_DATA_SOURCE_TYPE': 'database',
|
'KEYWORD_DATA_SOURCE_TYPE': 'database',
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -212,6 +213,7 @@ class Config:
|
|||||||
self.MILVUS_USER = get_env('MILVUS_USER')
|
self.MILVUS_USER = get_env('MILVUS_USER')
|
||||||
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
|
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
|
||||||
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
|
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
|
||||||
|
self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')
|
||||||
|
|
||||||
# weaviate settings
|
# weaviate settings
|
||||||
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
|
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
|
||||||
|
@ -20,16 +20,17 @@ class MilvusConfig(BaseModel):
|
|||||||
password: str
|
password: str
|
||||||
secure: bool = False
|
secure: bool = False
|
||||||
batch_size: int = 100
|
batch_size: int = 100
|
||||||
|
database: str = "default"
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values['host']:
|
if not values.get('host'):
|
||||||
raise ValueError("config MILVUS_HOST is required")
|
raise ValueError("config MILVUS_HOST is required")
|
||||||
if not values['port']:
|
if not values.get('port'):
|
||||||
raise ValueError("config MILVUS_PORT is required")
|
raise ValueError("config MILVUS_PORT is required")
|
||||||
if not values['user']:
|
if not values.get('user'):
|
||||||
raise ValueError("config MILVUS_USER is required")
|
raise ValueError("config MILVUS_USER is required")
|
||||||
if not values['password']:
|
if not values.get('password'):
|
||||||
raise ValueError("config MILVUS_PASSWORD is required")
|
raise ValueError("config MILVUS_PASSWORD is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@ -39,7 +40,8 @@ class MilvusConfig(BaseModel):
|
|||||||
'port': self.port,
|
'port': self.port,
|
||||||
'user': self.user,
|
'user': self.user,
|
||||||
'password': self.password,
|
'password': self.password,
|
||||||
'secure': self.secure
|
'secure': self.secure,
|
||||||
|
'db_name': self.database,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -192,7 +194,7 @@ class MilvusVector(BaseVector):
|
|||||||
else:
|
else:
|
||||||
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||||
connections.connect(alias=alias, uri=uri, user=self._client_config.user,
|
connections.connect(alias=alias, uri=uri, user=self._client_config.user,
|
||||||
password=self._client_config.password)
|
password=self._client_config.password, db_name=self._client_config.database)
|
||||||
if not utility.has_collection(self._collection_name, using=alias):
|
if not utility.has_collection(self._collection_name, using=alias):
|
||||||
from pymilvus import CollectionSchema, DataType, FieldSchema
|
from pymilvus import CollectionSchema, DataType, FieldSchema
|
||||||
from pymilvus.orm.types import infer_dtype_bydata
|
from pymilvus.orm.types import infer_dtype_bydata
|
||||||
|
@ -110,6 +110,7 @@ class Vector:
|
|||||||
user=config.get('MILVUS_USER'),
|
user=config.get('MILVUS_USER'),
|
||||||
password=config.get('MILVUS_PASSWORD'),
|
password=config.get('MILVUS_PASSWORD'),
|
||||||
secure=config.get('MILVUS_SECURE'),
|
secure=config.get('MILVUS_SECURE'),
|
||||||
|
database=config.get('MILVUS_DATABASE'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
24
api/tests/unittests/test_model.py
Normal file
24
api/tests/unittests/test_model.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import pytest
|
||||||
|
from pydantic.error_wrappers import ValidationError
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_value():
|
||||||
|
valid_config = {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 19530,
|
||||||
|
'user': 'root',
|
||||||
|
'password': 'Milvus'
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in valid_config:
|
||||||
|
config = valid_config.copy()
|
||||||
|
del config[key]
|
||||||
|
with pytest.raises(ValidationError) as e:
|
||||||
|
MilvusConfig(**config)
|
||||||
|
assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required'
|
||||||
|
|
||||||
|
config = MilvusConfig(**valid_config)
|
||||||
|
assert config.secure is False
|
||||||
|
assert config.database == 'default'
|
Loading…
x
Reference in New Issue
Block a user