Merge pull request #13670 from HarrisonConsulting/fix/milvus-standalone-index

fix: enhance MilvusClient with dynamic index type and improved logging
This commit is contained in:
Tim Jaeryang Baek 2025-05-08 22:08:23 +04:00 committed by GitHub
commit 1fea4f794f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 115 additions and 49 deletions

View File

@ -1765,6 +1765,12 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
MILVUS_DB = os.environ.get("MILVUS_DB", "default") MILVUS_DB = os.environ.get("MILVUS_DB", "default")
MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None) MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)
# Index algorithm for the Milvus vector field (the client upper-cases it before use).
MILVUS_INDEX_TYPE = os.getenv("MILVUS_INDEX_TYPE", "HNSW")
# Similarity metric used when the index is created (also upper-cased by the client).
MILVUS_METRIC_TYPE = os.getenv("MILVUS_METRIC_TYPE", "COSINE")
# HNSW build parameters, sent as "M" and "efConstruction" when the index type is HNSW.
MILVUS_HNSW_M = int(os.getenv("MILVUS_HNSW_M", "16"))
MILVUS_HNSW_EFCONSTRUCTION = int(os.getenv("MILVUS_HNSW_EFCONSTRUCTION", "100"))
# IVF_FLAT build parameter, sent as "nlist" when the index type is IVF_FLAT.
MILVUS_IVF_FLAT_NLIST = int(os.getenv("MILVUS_IVF_FLAT_NLIST", "128"))
# Qdrant # Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None) QDRANT_URI = os.environ.get("QDRANT_URI", None)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None) QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)

View File

@ -3,7 +3,6 @@ from pymilvus import FieldSchema, DataType
import json import json
import logging import logging
from typing import Optional from typing import Optional
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
VectorItem, VectorItem,
@ -14,13 +13,17 @@ from open_webui.config import (
MILVUS_URI, MILVUS_URI,
MILVUS_DB, MILVUS_DB,
MILVUS_TOKEN, MILVUS_TOKEN,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient(VectorDBBase): class MilvusClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = "open_webui" self.collection_prefix = "open_webui"
@ -33,7 +36,6 @@ class MilvusClient(VectorDBBase):
ids = [] ids = []
documents = [] documents = []
metadatas = [] metadatas = []
for match in result: for match in result:
_ids = [] _ids = []
_documents = [] _documents = []
@ -42,11 +44,9 @@ class MilvusClient(VectorDBBase):
_ids.append(item.get("id")) _ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text")) _documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata")) _metadatas.append(item.get("metadata"))
ids.append(_ids) ids.append(_ids)
documents.append(_documents) documents.append(_documents)
metadatas.append(_metadatas) metadatas.append(_metadatas)
return GetResult( return GetResult(
**{ **{
"ids": ids, "ids": ids,
@ -60,13 +60,11 @@ class MilvusClient(VectorDBBase):
distances = [] distances = []
documents = [] documents = []
metadatas = [] metadatas = []
for match in result: for match in result:
_ids = [] _ids = []
_distances = [] _distances = []
_documents = [] _documents = []
_metadatas = [] _metadatas = []
for item in match: for item in match:
_ids.append(item.get("id")) _ids.append(item.get("id"))
# normalize milvus score from [-1, 1] to [0, 1] range # normalize milvus score from [-1, 1] to [0, 1] range
@ -75,12 +73,10 @@ class MilvusClient(VectorDBBase):
_distances.append(_dist) _distances.append(_dist)
_documents.append(item.get("entity", {}).get("data", {}).get("text")) _documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata")) _metadatas.append(item.get("entity", {}).get("metadata"))
ids.append(_ids) ids.append(_ids)
distances.append(_distances) distances.append(_distances)
documents.append(_documents) documents.append(_documents)
metadatas.append(_metadatas) metadatas.append(_metadatas)
return SearchResult( return SearchResult(
**{ **{
"ids": ids, "ids": ids,
@ -113,11 +109,36 @@ class MilvusClient(VectorDBBase):
) )
index_params = self.client.prepare_index_params() index_params = self.client.prepare_index_params()
# Use configurations from config.py
index_type = MILVUS_INDEX_TYPE.upper()
metric_type = MILVUS_METRIC_TYPE.upper()
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
index_creation_params = {}
if index_type == "HNSW":
index_creation_params = {"M": MILVUS_HNSW_M, "efConstruction": MILVUS_HNSW_EFCONSTRUCTION}
log.info(f"HNSW params: {index_creation_params}")
elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.")
else:
log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
)
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
index_params.add_index( index_params.add_index(
field_name="vector", field_name="vector",
index_type="HNSW", index_type=index_type,
metric_type="COSINE", metric_type=metric_type,
params={"M": 16, "efConstruction": 100}, params=index_creation_params,
) )
self.client.create_collection( self.client.create_collection(
@ -125,6 +146,8 @@ class MilvusClient(VectorDBBase):
schema=schema, schema=schema,
index_params=index_params, index_params=index_params,
) )
log.info(f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'.")
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name. # Check if the collection exists based on the collection name.
@ -145,84 +168,95 @@ class MilvusClient(VectorDBBase):
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace("-", "_")
# For some index types like IVF_FLAT, search params like nprobe can be set.
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
# For simplicity, not adding configurable search_params here, but could be extended.
result = self.client.search( result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors, data=vectors,
limit=limit, limit=limit,
output_fields=["data", "metadata"], output_fields=["data", "metadata"],
# search_params=search_params # Potentially add later if needed
) )
return self._result_to_search_result(result) return self._result_to_search_result(result)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
    """Return the items whose metadata matches ``filter``, fetched page by page.

    Args:
        collection_name: Logical collection name. Hyphens are replaced with
            underscores and the client's collection prefix is prepended.
        filter: Mapping of metadata key to required value. Each entry becomes
            a ``metadata["key"] == <JSON value>`` clause and the clauses are
            joined with ``&&``. An empty mapping yields an empty expression.
        limit: Maximum number of items to return. ``None`` means no upper
            bound: pages are requested until a short or empty page comes back.

    Returns:
        The output of ``_result_to_get_result`` over all fetched rows, or
        ``None`` if the collection does not exist or the query raised.
    """
    collection_name = collection_name.replace("-", "_")
    full_name = f"{self.collection_prefix}_{collection_name}"
    if not self.has_collection(collection_name):
        log.warning(f"Query attempted on non-existent collection: {full_name}")
        return None

    # Construct the filter string for querying.
    filter_string = " && ".join(
        f'metadata["{key}"] == {json.dumps(value)}'
        for key, value in (filter or {}).items()
    )

    max_limit = 16383  # The maximum number of records per request
    all_results = []

    # ``remaining is None`` means "unbounded"; no artificial cap is applied,
    # so a caller asking for everything is never silently truncated.
    remaining = limit
    offset = 0
    if remaining is None:
        log.info("Limit not specified for query, fetching all matching results in batches.")

    try:
        log.info(
            f"Querying collection {full_name} with filter: '{filter_string}', limit: {limit}"
        )
        # Loop until there are no more items to fetch or the desired limit is reached.
        # NOTE(review): Milvus is documented to bound offset + limit per request;
        # confirm deep pagination works on the deployed version, or move to an
        # iterator-based query if it does not.
        while remaining is None or remaining > 0:
            current_fetch = max_limit if remaining is None else min(max_limit, remaining)
            log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")

            results = self.client.query(
                collection_name=full_name,
                filter=filter_string,
                # Only the fields consumed by _result_to_get_result; the vector is not needed.
                output_fields=["id", "data", "metadata"],
                limit=current_fetch,
                offset=offset,
            )

            if not results:
                log.debug("No more results from query.")
                break

            all_results.extend(results)
            results_count = len(results)
            log.debug(f"Fetched {results_count} results in this batch.")

            if remaining is not None:
                remaining -= results_count
            offset += results_count

            # A short page means the server has no further rows for this filter.
            if results_count < current_fetch:
                log.debug("Fetched less than requested, assuming end of results for this query.")
                break

        log.info(f"Total results from query: {len(all_results)}")
        return self._result_to_get_result([all_results])
    except Exception as e:
        log.exception(
            f"Error querying collection {full_name} with filter '{filter_string}' and limit {limit}: {e}"
        )
        return None
def get(self, collection_name: str) -> Optional[GetResult]:
    """Fetch every item in the collection by delegating to ``query``.

    Hyphens in ``collection_name`` are replaced with underscores before use.
    A warning is logged on every call because reading a whole collection can
    be slow. Returns whatever ``query`` returns for an empty filter and no limit.
    """
    normalized_name = collection_name.replace("-", "_")
    log.warning(
        f"Fetching ALL items from collection '{self.collection_prefix}_{normalized_name}'. This might be slow for large collections."
    )
    # An empty filter with no limit routes through the paginated query logic.
    return self.query(collection_name=normalized_name, filter={}, limit=None)
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
@ -230,10 +264,15 @@ class MilvusClient(VectorDBBase):
if not self.client.has_collection( if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}" collection_name=f"{self.collection_prefix}_{collection_name}"
): ):
log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.")
if not items:
log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.")
raise ValueError("Cannot create Milvus collection without items to determine vector dimension.")
self._create_collection( self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"]) collection_name=collection_name, dimension=len(items[0]["vector"])
) )
log.info(f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
return self.client.insert( return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
data=[ data=[
@ -253,10 +292,15 @@ class MilvusClient(VectorDBBase):
if not self.client.has_collection( if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}" collection_name=f"{self.collection_prefix}_{collection_name}"
): ):
log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.")
if not items:
log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.")
raise ValueError("Cannot create Milvus collection for upsert without items to determine vector dimension.")
self._create_collection( self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"]) collection_name=collection_name, dimension=len(items[0]["vector"])
) )
log.info(f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
return self.client.upsert( return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
data=[ data=[
@ -276,30 +320,46 @@ class MilvusClient(VectorDBBase):
ids: Optional[list[str]] = None, ids: Optional[list[str]] = None,
filter: Optional[dict] = None, filter: Optional[dict] = None,
): ):
# Delete the items from the collection based on the ids. # Delete the items from the collection based on the ids or filter.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
log.warning(f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
return None
if ids: if ids:
log.info(f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}")
return self.client.delete( return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids, ids=ids,
) )
elif filter: elif filter:
# Convert the filter dictionary to a filter expression: one metadata["key"] == <JSON value> clause per entry, joined with &&.
filter_string = " && ".join( filter_string = " && ".join(
[ [
f'metadata["{key}"] == {json.dumps(value)}' f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items() for key, value in filter.items()
] ]
) )
log.info(f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}")
return self.client.delete( return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string, filter=filter_string,
) )
else:
log.warning(f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.")
return None
def reset(self):
    """Drop every collection whose name starts with this client's prefix.

    Collections are dropped one at a time on a best-effort basis: a failure
    on one collection is logged with its traceback and the loop moves on to
    the next. The names that were dropped are logged at the end.
    """
    log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
    deleted_collections = []
    for collection_name_full in self.client.list_collections():
        if not collection_name_full.startswith(self.collection_prefix):
            continue
        try:
            self.client.drop_collection(collection_name=collection_name_full)
        except Exception as e:
            # log.exception keeps the traceback, matching the error handling in query().
            log.exception(f"Error deleting collection {collection_name_full}: {e}")
        else:
            deleted_collections.append(collection_name_full)
            log.info(f"Deleted collection: {collection_name_full}")
    log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")