diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 2b23cad17..410945c81 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -260,23 +260,47 @@ def query_collection( k: int, ) -> dict: results = [] - for query in queries: - log.debug(f"query_collection:query {query}") - query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX) - for collection_name in collection_names: + error = False + + def process_query_collection(collection_name, query_embedding): + try: if collection_name: - try: - result = query_doc( - collection_name=collection_name, - k=k, - query_embedding=query_embedding, - ) - if result is not None: - results.append(result.model_dump()) - except Exception as e: - log.exception(f"Error when querying the collection: {e}") - else: - pass + result = query_doc( + collection_name=collection_name, + k=k, + query_embedding=query_embedding, + ) + if result is not None: + return result.model_dump(), None + return None, None + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + return None, e + + # Generate all query embeddings (in one call) + query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) + log.debug( + f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections" + ) + + with ThreadPoolExecutor() as executor: + future_results = [] + for query_embedding in query_embeddings: + for collection_name in collection_names: + result = executor.submit( + process_query_collection, collection_name, query_embedding + ) + future_results.append(result) + task_results = [future.result() for future in future_results] + + for result, err in task_results: + if err is not None: + error = True + elif result is not None: + results.append(result) + + if error and not results: + log.warning("All collection queries failed. No results returned.") return merge_and_sort_query_results(results, k=k)