refactor: enhance MilvusClient with dynamic index type and improved logging

Matt Harrison 2025-05-07 21:51:28 -04:00
parent b34401a087
commit 5e46c27806


@@ -1,9 +1,9 @@
+import os  # Added import
 from pymilvus import MilvusClient as Client
 from pymilvus import FieldSchema, DataType
 import json
 import logging
 from typing import Optional
-
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
     VectorItem,
@@ -20,7 +20,6 @@ from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
-
 class MilvusClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open_webui"
@@ -33,7 +32,6 @@ class MilvusClient(VectorDBBase):
         ids = []
         documents = []
         metadatas = []
-
         for match in result:
             _ids = []
             _documents = []
@@ -42,11 +40,9 @@ class MilvusClient(VectorDBBase):
                 _ids.append(item.get("id"))
                 _documents.append(item.get("data", {}).get("text"))
                 _metadatas.append(item.get("metadata"))
-
             ids.append(_ids)
             documents.append(_documents)
             metadatas.append(_metadatas)
-
         return GetResult(
             **{
                 "ids": ids,
@@ -60,13 +56,11 @@ class MilvusClient(VectorDBBase):
         distances = []
         documents = []
         metadatas = []
-
         for match in result:
             _ids = []
             _distances = []
             _documents = []
             _metadatas = []
-
             for item in match:
                 _ids.append(item.get("id"))
                 # normalize milvus score from [-1, 1] to [0, 1] range
@@ -75,12 +69,10 @@ class MilvusClient(VectorDBBase):
                 _distances.append(_dist)
                 _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                 _metadatas.append(item.get("entity", {}).get("metadata"))
-
             ids.append(_ids)
             distances.append(_distances)
             documents.append(_documents)
             metadatas.append(_metadatas)
-
         return SearchResult(
             **{
                 "ids": ids,
@@ -113,11 +105,69 @@ class MilvusClient(VectorDBBase):
         )
 
         index_params = self.client.prepare_index_params()
+
+        # Get index type from environment variable.
+        # Milvus standalone (local mode) supports: FLAT, IVF_FLAT, AUTOINDEX.
+        # HNSW is often preferred for performance but may require a clustered Milvus setup.
+        # Defaulting to AUTOINDEX for broader compatibility, especially with Milvus standalone.
+        default_index_type = "AUTOINDEX"
+        milvus_index_type_env = os.getenv("MILVUS_INDEX_TYPE")
+        if milvus_index_type_env:
+            milvus_index_type = milvus_index_type_env.upper()
+            log.info(f"Milvus index type from MILVUS_INDEX_TYPE env var: {milvus_index_type}")
+        else:
+            milvus_index_type = default_index_type
+            log.info(f"MILVUS_INDEX_TYPE env var not set, defaulting to: {milvus_index_type}")
+
+        index_creation_params = {}
+        metric_type = os.getenv("MILVUS_METRIC_TYPE", "COSINE").upper()  # Default to COSINE
+
+        if milvus_index_type == "HNSW":
+            # Parameters for HNSW
+            m_env = os.getenv("MILVUS_HNSW_M", "16")
+            ef_construction_env = os.getenv("MILVUS_HNSW_EFCONSTRUCTION", "100")
+            try:
+                m_val = int(m_env)
+                ef_val = int(ef_construction_env)
+            except ValueError:
+                log.warning(f"Invalid HNSW params M='{m_env}' or efConstruction='{ef_construction_env}'. Defaulting to M=16, efConstruction=100.")
+                m_val = 16
+                ef_val = 100
+            index_creation_params = {"M": m_val, "efConstruction": ef_val}
+            log.info(f"Using HNSW index with metric {metric_type}, params: {index_creation_params}")
+        elif milvus_index_type == "IVF_FLAT":
+            # Parameters for IVF_FLAT
+            nlist_env = os.getenv("MILVUS_IVF_FLAT_NLIST", "128")
+            try:
+                nlist = int(nlist_env)
+            except ValueError:
+                log.warning(f"Invalid MILVUS_IVF_FLAT_NLIST value '{nlist_env}'. Defaulting to 128.")
+                nlist = 128
+            index_creation_params = {"nlist": nlist}
+            log.info(f"Using IVF_FLAT index with metric {metric_type}, params: {index_creation_params}")
+        elif milvus_index_type == "FLAT":
+            # No specific build-time parameters needed for FLAT
+            log.info(f"Using FLAT index with metric {metric_type} (no specific build-time params).")
+        elif milvus_index_type == "AUTOINDEX":
+            # No specific build-time parameters needed for AUTOINDEX
+            log.info(f"Using AUTOINDEX with metric {metric_type} (params managed by Milvus).")
+        else:
+            log.warning(
+                f"Unsupported or unrecognized MILVUS_INDEX_TYPE: '{milvus_index_type}'. "
+                f"Falling back to '{default_index_type}'. "
+                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX."
+            )
+            milvus_index_type = default_index_type  # Fall back to a safe default
+            # index_creation_params remains {}, which is fine for AUTOINDEX/FLAT
+            log.info(f"Fell back to {default_index_type} index with metric {metric_type}.")
+
         index_params.add_index(
             field_name="vector",
-            index_type="HNSW",
-            metric_type="COSINE",
-            params={"M": 16, "efConstruction": 100},
+            index_type=milvus_index_type,
+            metric_type=metric_type,
+            params=index_creation_params,
         )
 
         self.client.create_collection(
@@ -125,6 +175,8 @@ class MilvusClient(VectorDBBase):
             schema=schema,
             index_params=index_params,
         )
+        log.info(f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{milvus_index_type}' and metric '{metric_type}'.")
 
     def has_collection(self, collection_name: str) -> bool:
         # Check if the collection exists based on the collection name.
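With this change, index behavior is driven entirely by environment variables. A sketch of how a deployment might configure HNSW (the variable names come from the diff above; the values are illustrative):

```python
import os

# Illustrative values only; set these before the collection is first created.
os.environ["MILVUS_INDEX_TYPE"] = "HNSW"          # HNSW, IVF_FLAT, FLAT, or AUTOINDEX
os.environ["MILVUS_METRIC_TYPE"] = "COSINE"       # metric passed to add_index()
os.environ["MILVUS_HNSW_M"] = "32"                # HNSW graph degree
os.environ["MILVUS_HNSW_EFCONSTRUCTION"] = "200"  # HNSW build-time search width

# _create_collection() will now build the vector index as
# HNSW(M=32, efConstruction=200) with the COSINE metric.
```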
@@ -145,84 +197,95 @@ class MilvusClient(VectorDBBase):
     ) -> Optional[SearchResult]:
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
         collection_name = collection_name.replace("-", "_")
+        # For some index types like IVF_FLAT, search params like nprobe can be set.
+        # Example: search_params = {"nprobe": 10} if using IVF_FLAT.
+        # For simplicity, configurable search_params are not added here, but this could be extended.
         result = self.client.search(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=vectors,
             limit=limit,
             output_fields=["data", "metadata"],
+            # search_params=search_params  # Potentially add later if needed
         )
 
         return self._result_to_search_result(result)
 
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
         # Construct the filter string for querying
         collection_name = collection_name.replace("-", "_")
         if not self.has_collection(collection_name):
+            log.warning(f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
             return None
 
         filter_string = " && ".join(
             [
                 f'metadata["{key}"] == {json.dumps(value)}'
                 for key, value in filter.items()
             ]
         )
 
         max_limit = 16383  # The maximum number of records per request
         all_results = []
 
         if limit is None:
-            limit = float("inf")  # Use infinity as a placeholder for no limit
+            # Milvus's default limit for an unspecified query is 16384, but the docs suggest iterating.
+            # For now, if limit is None, fetch in batches up to a very large number;
+            # this could be refined based on the expected use cases for "get all".
+            # For this function signature, None implies "as many as possible" up to Milvus limits.
+            limit = 16384 * 10  # A large number signifying "fetch many"; capped by actual data or max_limit per call.
+            log.info(f"Limit not specified for query, fetching up to {limit} results in batches.")
 
         # Initialize offset and remaining to handle pagination
         offset = 0
         remaining = limit
 
         try:
+            log.info(f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}")
             # Loop until there are no more items to fetch or the desired limit is reached
             while remaining > 0:
-                log.info(f"remaining: {remaining}")
-                current_fetch = min(
-                    max_limit, remaining
-                )  # Determine how many items to fetch in this iteration
+                current_fetch = min(max_limit, remaining if isinstance(remaining, int) else max_limit)
+                log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")
                 results = self.client.query(
                     collection_name=f"{self.collection_prefix}_{collection_name}",
                     filter=filter_string,
-                    output_fields=["*"],
+                    output_fields=["id", "data", "metadata"],  # Explicitly list needed fields; the vector is usually not needed in query.
                     limit=current_fetch,
                     offset=offset,
                 )
 
                 if not results:
+                    log.debug("No more results from query.")
                     break
 
                 all_results.extend(results)
                 results_count = len(results)
-                remaining -= (
-                    results_count  # Decrease remaining by the number of items fetched
-                )
+                log.debug(f"Fetched {results_count} results in this batch.")
+                if isinstance(remaining, int):
+                    remaining -= results_count
                 offset += results_count
 
-                # Break the loop if the results returned are less than the requested fetch count
+                # Break the loop if fewer results were returned than requested (means end of data)
                 if results_count < current_fetch:
+                    log.debug("Fetched less than requested, assuming end of results for this query.")
                     break
 
-            log.debug(all_results)
+            log.info(f"Total results from query: {len(all_results)}")
             return self._result_to_get_result([all_results])
         except Exception as e:
             log.exception(
-                f"Error querying collection {collection_name} with limit {limit}: {e}"
+                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
             )
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
-        # Get all the items in the collection.
+        # Get all the items in the collection. This can be very resource-intensive for large collections.
         collection_name = collection_name.replace("-", "_")
-        result = self.client.query(
-            collection_name=f"{self.collection_prefix}_{collection_name}",
-            filter='id != ""',
-        )
-        return self._result_to_get_result([result])
+        log.warning(f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections.")
+        # Use query() with a trivial filter to get all items; this reuses the paginated query logic.
+        return self.query(collection_name=collection_name, filter={}, limit=None)
 
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.
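The `query()` batching above is a standard offset-pagination loop against the 16383-records-per-request cap. A condensed, self-contained sketch of the same pattern (`fetch_page` is a hypothetical stand-in for `client.query`):

```python
def fetch_all(fetch_page, total_limit: int, page_size: int = 16383) -> list:
    # fetch_page(limit=..., offset=...) is a hypothetical stand-in for
    # pymilvus MilvusClient.query(); the control flow mirrors query() above.
    results: list = []
    offset = 0
    remaining = total_limit
    while remaining > 0:
        current = min(page_size, remaining)
        batch = fetch_page(limit=current, offset=offset)
        if not batch:
            break  # nothing left at this offset
        results.extend(batch)
        remaining -= len(batch)
        offset += len(batch)
        if len(batch) < current:
            break  # short page: end of data
    return results
```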
@@ -230,10 +293,15 @@ class MilvusClient(VectorDBBase):
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.")
+            if not items:
+                log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.")
+                raise ValueError("Cannot create Milvus collection without items to determine vector dimension.")
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
 
+        log.info(f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
         return self.client.insert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
@@ -253,10 +321,15 @@ class MilvusClient(VectorDBBase):
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.")
+            if not items:
+                log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.")
+                raise ValueError("Cannot create Milvus collection for upsert without items to determine vector dimension.")
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
 
+        log.info(f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
         return self.client.upsert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
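Both `insert()` and `upsert()` infer the vector dimension from the first item when the collection has to be created, which is why an empty item list is now rejected up front. A usage sketch with hypothetical data:

```python
# Collection name and 4-dimensional vectors are made up for illustration.
client = MilvusClient()
items = [
    {"id": "doc-1", "text": "hello", "vector": [0.1, 0.2, 0.3, 0.4], "metadata": {"source": "a.txt"}},
    {"id": "doc-2", "text": "world", "vector": [0.5, 0.6, 0.7, 0.8], "metadata": {"source": "b.txt"}},
]
# Creates "open_webui_my_docs" with dimension=4 (from items[0]["vector"]) and inserts both rows.
# Calling insert("my_docs", []) before the collection exists now raises ValueError
# instead of failing with an IndexError on items[0].
client.insert("my_docs", items)
```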
@@ -276,30 +349,46 @@ class MilvusClient(VectorDBBase):
         ids: Optional[list[str]] = None,
         filter: Optional[dict] = None,
     ):
-        # Delete the items from the collection based on the ids.
+        # Delete the items from the collection based on the ids or filter.
         collection_name = collection_name.replace("-", "_")
+        if not self.has_collection(collection_name):
+            log.warning(f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
+            return None
+
         if ids:
+            log.info(f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}")
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 ids=ids,
             )
         elif filter:
-            # Convert the filter dictionary to a string using JSON_CONTAINS.
             filter_string = " && ".join(
                 [
                     f'metadata["{key}"] == {json.dumps(value)}'
                     for key, value in filter.items()
                 ]
             )
+            log.info(f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}")
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 filter=filter_string,
             )
+        else:
+            log.warning(f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.")
+            return None
 
     def reset(self):
-        # Resets the database. This will delete all collections and item entries.
+        # Resets the database. This will delete all collections and item entries that match the prefix.
+        log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
         collection_names = self.client.list_collections()
-        for collection_name in collection_names:
-            if collection_name.startswith(self.collection_prefix):
-                self.client.drop_collection(collection_name=collection_name)
+        deleted_collections = []
+        for collection_name_full in collection_names:
+            if collection_name_full.startswith(self.collection_prefix):
+                try:
+                    self.client.drop_collection(collection_name=collection_name_full)
+                    deleted_collections.append(collection_name_full)
+                    log.info(f"Deleted collection: {collection_name_full}")
+                except Exception as e:
+                    log.error(f"Error deleting collection {collection_name_full}: {e}")
+        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")