mirror of
https://git.mirrors.martin98.com/https://github.com/open-webui/open-webui
synced 2025-08-16 09:05:59 +08:00
refac: enhance MilvusClient with dynamic index type and improved logging
This commit is contained in:
parent
b34401a087
commit
5e46c27806
@ -1,9 +1,9 @@
|
||||
import os # Added import
|
||||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
@ -20,7 +20,6 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class MilvusClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
@ -33,7 +32,6 @@ class MilvusClient(VectorDBBase):
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for match in result:
|
||||
_ids = []
|
||||
_documents = []
|
||||
@ -42,11 +40,9 @@ class MilvusClient(VectorDBBase):
|
||||
_ids.append(item.get("id"))
|
||||
_documents.append(item.get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("metadata"))
|
||||
|
||||
ids.append(_ids)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
@ -60,13 +56,11 @@ class MilvusClient(VectorDBBase):
|
||||
distances = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for match in result:
|
||||
_ids = []
|
||||
_distances = []
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
# normalize milvus score from [-1, 1] to [0, 1] range
|
||||
@ -75,12 +69,10 @@ class MilvusClient(VectorDBBase):
|
||||
_distances.append(_dist)
|
||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
||||
|
||||
ids.append(_ids)
|
||||
distances.append(_distances)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
@ -113,11 +105,69 @@ class MilvusClient(VectorDBBase):
|
||||
)
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
|
||||
# Get index type from environment variable.
|
||||
# Milvus Lite (the embedded, local-file mode) supports only a subset of index types: FLAT, IVF_FLAT, AUTOINDEX.
|
||||
# HNSW is often preferred for performance but needs a full Milvus server deployment (NOTE(review): confirm for your Milvus version).
|
||||
# Defaulting to AUTOINDEX for broader compatibility, especially with the embedded local mode.
|
||||
default_index_type = "AUTOINDEX"
|
||||
milvus_index_type_env = os.getenv("MILVUS_INDEX_TYPE")
|
||||
|
||||
if milvus_index_type_env:
|
||||
milvus_index_type = milvus_index_type_env.upper()
|
||||
log.info(f"Milvus index type from MILVUS_INDEX_TYPE env var: {milvus_index_type}")
|
||||
else:
|
||||
milvus_index_type = default_index_type
|
||||
log.info(f"MILVUS_INDEX_TYPE env var not set, defaulting to: {milvus_index_type}")
|
||||
|
||||
index_creation_params = {}
|
||||
metric_type = os.getenv("MILVUS_METRIC_TYPE", "COSINE").upper() # Default to COSINE
|
||||
|
||||
if milvus_index_type == "HNSW":
|
||||
# Parameters for HNSW
|
||||
m_env = os.getenv("MILVUS_HNSW_M", "16")
|
||||
ef_construction_env = os.getenv("MILVUS_HNSW_EFCONSTRUCTION", "100")
|
||||
try:
|
||||
m_val = int(m_env)
|
||||
ef_val = int(ef_construction_env)
|
||||
except ValueError:
|
||||
log.warning(f"Invalid HNSW params M='{m_env}' or efConstruction='{ef_construction_env}'. Defaulting to M=16, efConstruction=100.")
|
||||
m_val = 16
|
||||
ef_val = 100
|
||||
index_creation_params = {"M": m_val, "efConstruction": ef_val}
|
||||
log.info(f"Using HNSW index with metric {metric_type}, params: {index_creation_params}")
|
||||
elif milvus_index_type == "IVF_FLAT":
|
||||
# Parameters for IVF_FLAT
|
||||
nlist_env = os.getenv("MILVUS_IVF_FLAT_NLIST", "128")
|
||||
try:
|
||||
nlist = int(nlist_env)
|
||||
except ValueError:
|
||||
log.warning(f"Invalid MILVUS_IVF_FLAT_NLIST value '{nlist_env}'. Defaulting to 128.")
|
||||
nlist = 128
|
||||
index_creation_params = {"nlist": nlist}
|
||||
log.info(f"Using IVF_FLAT index with metric {metric_type}, params: {index_creation_params}")
|
||||
elif milvus_index_type == "FLAT":
|
||||
log.info(f"Using FLAT index with metric {metric_type} (no specific build-time params).")
|
||||
# No specific build-time parameters needed for FLAT
|
||||
elif milvus_index_type == "AUTOINDEX":
|
||||
log.info(f"Using AUTOINDEX with metric {metric_type} (params managed by Milvus).")
|
||||
# No specific build-time parameters needed for AUTOINDEX
|
||||
else:
|
||||
log.warning(
|
||||
f"Unsupported or unrecognized MILVUS_INDEX_TYPE: '{milvus_index_type}'. "
|
||||
f"Falling back to '{default_index_type}'. "
|
||||
f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX."
|
||||
)
|
||||
milvus_index_type = default_index_type # Fallback to a safe default
|
||||
# index_creation_params remains {} which is fine for AUTOINDEX/FLAT
|
||||
log.info(f"Fell back to {default_index_type} index with metric {metric_type}.")
|
||||
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="HNSW",
|
||||
metric_type="COSINE",
|
||||
params={"M": 16, "efConstruction": 100},
|
||||
index_type=milvus_index_type,
|
||||
metric_type=metric_type,
|
||||
params=index_creation_params,
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
@ -125,6 +175,8 @@ class MilvusClient(VectorDBBase):
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
)
|
||||
log.info(f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{milvus_index_type}' and metric '{metric_type}'.")
|
||||
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# Check if the collection exists based on the collection name.
|
||||
@ -145,84 +197,95 @@ class MilvusClient(VectorDBBase):
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
# For some index types like IVF_FLAT, search params like nprobe can be set.
|
||||
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
|
||||
# For simplicity, not adding configurable search_params here, but could be extended.
|
||||
result = self.client.search(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=vectors,
|
||||
limit=limit,
|
||||
output_fields=["data", "metadata"],
|
||||
# search_params=search_params # Potentially add later if needed
|
||||
)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
    """Fetch the items whose metadata equals every key/value pair in ``filter``.

    Rows are pulled page by page with ``limit``/``offset`` until ``limit``
    rows have been collected, or until an empty or short page signals the
    end of the data.

    Args:
        collection_name: Collection name without the prefix; ``-`` is
            replaced by ``_`` before use.
        filter: Mapping of metadata key to required value. The pairs are
            AND-ed together; an empty mapping yields an empty filter
            expression.
        limit: Maximum number of rows to return. ``None`` means "as many
            as possible" and is capped at ``16384 * 10`` rows.

    Returns:
        Whatever ``self._result_to_get_result`` builds from the collected
        rows, or ``None`` when the collection does not exist or the query
        raised an exception.
    """
    collection_name = collection_name.replace("-", "_")
    full_name = f"{self.collection_prefix}_{collection_name}"

    if not self.has_collection(collection_name):
        log.warning(f"Query attempted on non-existent collection: {full_name}")
        return None

    # json.dumps quotes/escapes each value so it is a valid literal in the
    # boolean expression.
    filter_string = " && ".join(
        f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()
    )

    max_limit = 16383  # The maximum number of records per request

    if limit is None:
        # "No limit" is expressed as a large finite number so the paging
        # arithmetic below stays on plain integers.
        limit = 16384 * 10
        log.info(f"Limit not specified for query, fetching up to {limit} results in batches.")

    all_results = []
    offset = 0
    remaining = limit

    try:
        log.info(f"Querying collection {full_name} with filter: '{filter_string}', limit: {limit}")
        # NOTE(review): Milvus may restrict how large offset + limit can get
        # in a single query; confirm deep pagination works for big collections.
        while remaining > 0:
            current_fetch = min(max_limit, remaining)
            log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")

            results = self.client.query(
                collection_name=full_name,
                filter=filter_string,
                # Only the fields consumed downstream; the vector is not fetched.
                output_fields=["id", "data", "metadata"],
                limit=current_fetch,
                offset=offset,
            )

            if not results:
                log.debug("No more results from query.")
                break

            all_results.extend(results)
            results_count = len(results)
            log.debug(f"Fetched {results_count} results in this batch.")

            # Advance the window exactly once per batch.
            remaining -= results_count
            offset += results_count

            # A short page means the data is exhausted.
            if results_count < current_fetch:
                log.debug("Fetched less than requested, assuming end of results for this query.")
                break

        log.info(f"Total results from query: {len(all_results)}")
        return self._result_to_get_result([all_results])
    except Exception as e:
        # Boundary handler: log with traceback and report failure as None.
        log.exception(
            f"Error querying collection {full_name} with filter '{filter_string}' and limit {limit}: {e}"
        )
        return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
    """Return the items stored in the collection.

    This can be very resource-intensive for large collections. The work is
    delegated to ``self.query`` with an empty filter and ``limit=None``, so
    the rows are fetched through the paginated query path instead of one
    unbounded request.

    Args:
        collection_name: Collection name without the prefix; ``-`` is
            replaced by ``_`` before use.

    Returns:
        The result of ``self.query`` for an empty filter, which may be
        ``None``.
    """
    collection_name = collection_name.replace("-", "_")
    log.warning(
        f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
    )
    # An empty filter dict matches every row.
    return self.query(collection_name=collection_name, filter={}, limit=None)
|
||||
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
@ -230,10 +293,15 @@ class MilvusClient(VectorDBBase):
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.")
|
||||
if not items:
|
||||
log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.")
|
||||
raise ValueError("Cannot create Milvus collection without items to determine vector dimension.")
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
|
||||
log.info(f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
|
||||
return self.client.insert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
@ -253,10 +321,15 @@ class MilvusClient(VectorDBBase):
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.")
|
||||
if not items:
|
||||
log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.")
|
||||
raise ValueError("Cannot create Milvus collection for upsert without items to determine vector dimension.")
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
|
||||
log.info(f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
|
||||
return self.client.upsert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
@ -276,30 +349,46 @@ class MilvusClient(VectorDBBase):
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
# Delete the items from the collection based on the ids or filter.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
|
||||
return None
|
||||
|
||||
if ids:
|
||||
log.info(f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}")
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
ids=ids,
|
||||
)
|
||||
elif filter:
|
||||
# Convert the filter dictionary to a boolean expression: one metadata["key"] == value equality check per entry, joined with &&.
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f'metadata["{key}"] == {json.dumps(value)}'
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
|
||||
log.info(f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}")
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
)
|
||||
else:
|
||||
log.warning(f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.")
|
||||
return None
|
||||
|
||||
|
||||
def reset(self):
    """Reset the database by dropping every collection that matches the prefix.

    Each collection whose name starts with ``self.collection_prefix`` is
    dropped exactly once. A failure to drop one collection is logged and
    does not stop the remaining collections from being processed.
    """
    log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")

    deleted_collections = []
    for collection_name_full in self.client.list_collections():
        if not collection_name_full.startswith(self.collection_prefix):
            continue
        try:
            self.client.drop_collection(collection_name=collection_name_full)
        except Exception as e:
            # Best effort: keep going so one bad collection cannot block the reset.
            log.error(f"Error deleting collection {collection_name_full}: {e}")
        else:
            deleted_collections.append(collection_name_full)
            log.info(f"Deleted collection: {collection_name_full}")

    log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user