From 5e46c278067d48911e47d77b1fd6a87ec7a26919 Mon Sep 17 00:00:00 2001
From: Matt Harrison
Date: Wed, 7 May 2025 21:51:28 -0400
Subject: [PATCH] refactor: enhance MilvusClient with dynamic index type and improved logging

---
 .../open_webui/retrieval/vector/dbs/milvus.py | 187 +++++++++++++-----
 1 file changed, 138 insertions(+), 49 deletions(-)

diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py
index f116c57f7..0f8481c10 100644
--- a/backend/open_webui/retrieval/vector/dbs/milvus.py
+++ b/backend/open_webui/retrieval/vector/dbs/milvus.py
@@ -1,9 +1,9 @@
+import os
 from pymilvus import MilvusClient as Client
 from pymilvus import FieldSchema, DataType
 import json
 import logging
 from typing import Optional
-
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
     VectorItem,
@@ -20,7 +20,6 @@ from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
-
 class MilvusClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open_webui"
@@ -33,7 +32,6 @@ class MilvusClient(VectorDBBase):
         ids = []
         documents = []
         metadatas = []
-
         for match in result:
             _ids = []
             _documents = []
             _metadatas = []
@@ -42,11 +40,9 @@ class MilvusClient(VectorDBBase):
             _ids.append(item.get("id"))
             _documents.append(item.get("data", {}).get("text"))
             _metadatas.append(item.get("metadata"))
-
         ids.append(_ids)
         documents.append(_documents)
         metadatas.append(_metadatas)
-
         return GetResult(
             **{
                 "ids": ids,
@@ -60,13 +56,11 @@ class MilvusClient(VectorDBBase):
         ids = []
         distances = []
         documents = []
         metadatas = []
-
         for match in result:
             _ids = []
             _distances = []
             _documents = []
             _metadatas = []
-
             for item in match:
                 _ids.append(item.get("id"))
@@ -75,12 +69,10 @@ class MilvusClient(VectorDBBase):
                 # normalize milvus score from [-1, 1] to [0, 1] range
                 _dist = (item.get("distance") + 1.0) / 2.0
                 _distances.append(_dist)
                 _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                 _metadatas.append(item.get("entity", {}).get("metadata"))
-
             ids.append(_ids)
             distances.append(_distances)
             documents.append(_documents)
             metadatas.append(_metadatas)
-
         return SearchResult(
             **{
                 "ids": ids,
@@ -113,11 +105,69 @@ class MilvusClient(VectorDBBase):
         )
         index_params = self.client.prepare_index_params()
+
+        # Get index type from environment variable.
+        # Milvus standalone (local mode) supports: FLAT, IVF_FLAT, AUTOINDEX.
+        # HNSW is often preferred for performance but may require a clustered Milvus setup.
+        # Defaulting to AUTOINDEX for broader compatibility, especially with Milvus standalone.
+        default_index_type = "AUTOINDEX"
+        milvus_index_type_env = os.getenv("MILVUS_INDEX_TYPE")
+
+        if milvus_index_type_env:
+            milvus_index_type = milvus_index_type_env.upper()
+            log.info(f"Milvus index type from MILVUS_INDEX_TYPE env var: {milvus_index_type}")
+        else:
+            milvus_index_type = default_index_type
+            log.info(f"MILVUS_INDEX_TYPE env var not set, defaulting to: {milvus_index_type}")
+
+        index_creation_params = {}
+        metric_type = os.getenv("MILVUS_METRIC_TYPE", "COSINE").upper()  # Default to COSINE
+
+        if milvus_index_type == "HNSW":
+            # Parameters for HNSW
+            m_env = os.getenv("MILVUS_HNSW_M", "16")
+            ef_construction_env = os.getenv("MILVUS_HNSW_EFCONSTRUCTION", "100")
+            try:
+                m_val = int(m_env)
+                ef_val = int(ef_construction_env)
+            except ValueError:
+                log.warning(f"Invalid HNSW params M='{m_env}' or efConstruction='{ef_construction_env}'. Defaulting to M=16, efConstruction=100.")
Defaulting to M=16, efConstruction=100.") + m_val = 16 + ef_val = 100 + index_creation_params = {"M": m_val, "efConstruction": ef_val} + log.info(f"Using HNSW index with metric {metric_type}, params: {index_creation_params}") + elif milvus_index_type == "IVF_FLAT": + # Parameters for IVF_FLAT + nlist_env = os.getenv("MILVUS_IVF_FLAT_NLIST", "128") + try: + nlist = int(nlist_env) + except ValueError: + log.warning(f"Invalid MILVUS_IVF_FLAT_NLIST value '{nlist_env}'. Defaulting to 128.") + nlist = 128 + index_creation_params = {"nlist": nlist} + log.info(f"Using IVF_FLAT index with metric {metric_type}, params: {index_creation_params}") + elif milvus_index_type == "FLAT": + log.info(f"Using FLAT index with metric {metric_type} (no specific build-time params).") + # No specific build-time parameters needed for FLAT + elif milvus_index_type == "AUTOINDEX": + log.info(f"Using AUTOINDEX with metric {metric_type} (params managed by Milvus).") + # No specific build-time parameters needed for AUTOINDEX + else: + log.warning( + f"Unsupported or unrecognized MILVUS_INDEX_TYPE: '{milvus_index_type}'. " + f"Falling back to '{default_index_type}'. " + f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX." + ) + milvus_index_type = default_index_type # Fallback to a safe default + # index_creation_params remains {} which is fine for AUTOINDEX/FLAT + log.info(f"Fell back to {default_index_type} index with metric {metric_type}.") + + index_params.add_index( field_name="vector", - index_type="HNSW", - metric_type="COSINE", - params={"M": 16, "efConstruction": 100}, + index_type=milvus_index_type, + metric_type=metric_type, + params=index_creation_params, ) self.client.create_collection( @@ -125,6 +175,8 @@ class MilvusClient(VectorDBBase): schema=schema, index_params=index_params, ) + log.info(f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{milvus_index_type}' and metric '{metric_type}'.") + def has_collection(self, collection_name: str) -> bool: # Check if the collection exists based on the collection name. @@ -145,84 +197,95 @@ class MilvusClient(VectorDBBase): ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. collection_name = collection_name.replace("-", "_") + # For some index types like IVF_FLAT, search params like nprobe can be set. + # Example: search_params = {"nprobe": 10} if using IVF_FLAT + # For simplicity, not adding configurable search_params here, but could be extended. result = self.client.search( collection_name=f"{self.collection_prefix}_{collection_name}", data=vectors, limit=limit, output_fields=["data", "metadata"], + # search_params=search_params # Potentially add later if needed ) - return self._result_to_search_result(result) def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): # Construct the filter string for querying collection_name = collection_name.replace("-", "_") if not self.has_collection(collection_name): + log.warning(f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}") return None - filter_string = " && ".join( [ f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items() ] ) - max_limit = 16383 # The maximum number of records per request all_results = [] - if limit is None: - limit = float("inf") # Use infinity as a placeholder for no limit + # Milvus default limit for query if not specified is 16384, but docs mention iteration. 
+            # Let's set a practical high number if "all" is intended, or handle true pagination.
+            # For now, if limit is None, we'll fetch in batches up to a very large number.
+            # This part could be refined based on expected use cases for "get all".
+            # For this function signature, None implies "as many as possible" up to Milvus limits.
+            limit = 16384 * 10  # A large number to signify fetching many; capped by the actual data and by max_limit per call.
+            log.info(f"Limit not specified for query, fetching up to {limit} results in batches.")
+
         # Initialize offset and remaining to handle pagination
         offset = 0
         remaining = limit
-
         try:
+            log.info(f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}")
             # Loop until there are no more items to fetch or the desired limit is reached
             while remaining > 0:
-                log.info(f"remaining: {remaining}")
-                current_fetch = min(
-                    max_limit, remaining
-                )  # Determine how many items to fetch in this iteration
-
+                current_fetch = min(max_limit, remaining if isinstance(remaining, int) else max_limit)
+                log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")
+
                 results = self.client.query(
                     collection_name=f"{self.collection_prefix}_{collection_name}",
                     filter=filter_string,
-                    output_fields=["*"],
+                    output_fields=["id", "data", "metadata"],  # Explicitly list needed fields. Vector not usually needed in query.
                     limit=current_fetch,
                     offset=offset,
                 )
-
                 if not results:
+                    log.debug("No more results from query.")
                     break
-
                 all_results.extend(results)
                 results_count = len(results)
-                remaining -= (
-                    results_count  # Decrease remaining by the number of items fetched
-                )
+                log.debug(f"Fetched {results_count} results in this batch.")
+
+                if isinstance(remaining, int):
+                    remaining -= results_count
+
                 offset += results_count
-
-                # Break the loop if the results returned are less than the requested fetch count
+
+                # Break the loop if the results returned are less than the requested fetch count (i.e., end of data)
                 if results_count < current_fetch:
+                    log.debug("Fetched less than requested, assuming end of results for this query.")
                     break
-
-            log.debug(all_results)
+
+            log.info(f"Total results from query: {len(all_results)}")
             return self._result_to_get_result([all_results])
         except Exception as e:
             log.exception(
-                f"Error querying collection {collection_name} with limit {limit}: {e}"
+                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
             )
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
-        # Get all the items in the collection.
+        # Get all the items in the collection. This can be very resource-intensive for large collections.
         collection_name = collection_name.replace("-", "_")
-        result = self.client.query(
-            collection_name=f"{self.collection_prefix}_{collection_name}",
-            filter='id != ""',
-        )
-        return self._result_to_get_result([result])
+        log.warning(f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections.")
+        # Using query with a trivial filter to get all items.
+        # This will use the paginated query logic.
+        return self.query(collection_name=collection_name, filter={}, limit=None)
+
 
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.
@@ -230,10 +293,15 @@ class MilvusClient(VectorDBBase):
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.")
+            if not items:
+                log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.")
+                raise ValueError("Cannot create Milvus collection without items to determine vector dimension.")
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
-
+
+        log.info(f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
         return self.client.insert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
@@ -253,10 +321,15 @@ class MilvusClient(VectorDBBase):
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.")
+            if not items:
+                log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.")
+                raise ValueError("Cannot create Milvus collection for upsert without items to determine vector dimension.")
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
-
+
+        log.info(f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
         return self.client.upsert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
@@ -276,30 +349,46 @@ class MilvusClient(VectorDBBase):
         ids: Optional[list[str]] = None,
         filter: Optional[dict] = None,
     ):
-        # Delete the items from the collection based on the ids.
+        # Delete the items from the collection based on the ids or filter.
        collection_name = collection_name.replace("-", "_")
+        if not self.has_collection(collection_name):
+            log.warning(f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
+            return None
+
         if ids:
+            log.info(f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}")
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 ids=ids,
             )
         elif filter:
-            # Convert the filter dictionary to a string using JSON_CONTAINS.
             filter_string = " && ".join(
                 [
                     f'metadata["{key}"] == {json.dumps(value)}'
                     for key, value in filter.items()
                 ]
             )
-
+            log.info(f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}")
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 filter=filter_string,
             )
+        else:
+            log.warning(f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.")
+            return None
+
 
     def reset(self):
-        # Resets the database. This will delete all collections and item entries.
+        # Resets the database. This will delete all collections and item entries that match the prefix.
+ log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.") collection_names = self.client.list_collections() - for collection_name in collection_names: - if collection_name.startswith(self.collection_prefix): - self.client.drop_collection(collection_name=collection_name) + deleted_collections = [] + for collection_name_full in collection_names: + if collection_name_full.startswith(self.collection_prefix): + try: + self.client.drop_collection(collection_name=collection_name_full) + deleted_collections.append(collection_name_full) + log.info(f"Deleted collection: {collection_name_full}") + except Exception as e: + log.error(f"Error deleting collection {collection_name_full}: {e}") + log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")