mirror of
https://git.mirrors.martin98.com/https://github.com/open-webui/open-webui
synced 2025-08-18 04:25:52 +08:00
Merge pull request #5378 from thiswillbeyourgithub/fix_RAG_and_web
fix: RAG and Web Search + RAG enhancements
This commit is contained in:
commit
7dc4cb30b2
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
|
|||||||
k: int,
|
k: int,
|
||||||
reranking_function,
|
reranking_function,
|
||||||
r: float,
|
r: float,
|
||||||
):
|
) -> dict:
|
||||||
try:
|
try:
|
||||||
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
||||||
|
|
||||||
@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def merge_and_sort_query_results(query_results, k, reverse=False):
|
def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
|
||||||
# Initialize lists to store combined data
|
# Initialize lists to store combined data
|
||||||
combined_distances = []
|
combined_distances = []
|
||||||
combined_documents = []
|
combined_documents = []
|
||||||
@ -180,7 +181,7 @@ def query_collection(
|
|||||||
query: str,
|
query: str,
|
||||||
embedding_function,
|
embedding_function,
|
||||||
k: int,
|
k: int,
|
||||||
):
|
) -> dict:
|
||||||
results = []
|
results = []
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
if collection_name:
|
if collection_name:
|
||||||
@ -192,8 +193,8 @@ def query_collection(
|
|||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
log.exception(f"Error when querying the collection: {e}")
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
|
|||||||
k: int,
|
k: int,
|
||||||
reranking_function,
|
reranking_function,
|
||||||
r: float,
|
r: float,
|
||||||
):
|
) -> dict:
|
||||||
results = []
|
results = []
|
||||||
|
failed = 0
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
try:
|
try:
|
||||||
result = query_doc_with_hybrid_search(
|
result = query_doc_with_hybrid_search(
|
||||||
@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
|
|||||||
r=r,
|
r=r,
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
log.exception(
|
||||||
|
"Error when querying the collection with "
|
||||||
|
f"hybrid_search: {e}"
|
||||||
|
)
|
||||||
|
failed += 1
|
||||||
|
if failed == len(collection_names):
|
||||||
|
raise Exception("Hybrid search failed for all collections. Using "
|
||||||
|
"Non hybrid search as fallback.")
|
||||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
def rag_template(template: str, context: str, query: str) -> str:
    """Fill the ``[context]`` and ``[query]`` placeholders of a RAG prompt template.

    Args:
        template: Prompt template. It must contain the literal ``[context]``
            exactly once; ``[query]`` is substituted wherever it appears.
        context: Retrieved text inserted in place of ``[context]``.
        query: User query inserted in place of ``[query]``.

    Returns:
        The template with both placeholders substituted. Text that comes from
        ``context`` is never re-interpreted as a placeholder.

    Raises:
        AssertionError: If ``template`` does not contain exactly one
            ``[context]`` placeholder.
    """
    count = template.count("[context]")
    if count != 1:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for existing callers.
        raise AssertionError(
            f"RAG template contains an unexpected number of '[context]' : {count}"
        )

    if "<context>" in context and "</context>" in context:
        log.debug(
            "WARNING: Potential prompt injection attack: the RAG "
            "context contains '<context>' and '</context>'. This might be "
            "nothing, or the user might be trying to hack something."
        )

    if "[query]" in context:
        # The context itself contains the literal "[query]". Park the
        # template's own "[query]" behind a unique token first, so that the
        # final substitution only touches the template's placeholder and
        # leaves the occurrence inside the context untouched.
        query_placeholder = str(uuid.uuid4())
        template = template.replace("[query]", query_placeholder)
        template = template.replace("[context]", context)
        template = template.replace(query_placeholder, query)
    else:
        # Context goes in first: a "[context]" literal inside the query is
        # then inserted after the context pass and is never expanded.
        template = template.replace("[context]", context)
        template = template.replace("[query]", query)
    return template
|
||||||
|
|
||||||
|
|
||||||
@ -304,19 +331,25 @@ def get_rag_context(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
context = None
|
||||||
if file["type"] == "text":
|
if file["type"] == "text":
|
||||||
context = file["content"]
|
context = file["content"]
|
||||||
else:
|
else:
|
||||||
if hybrid_search:
|
if hybrid_search:
|
||||||
context = query_collection_with_hybrid_search(
|
try:
|
||||||
collection_names=collection_names,
|
context = query_collection_with_hybrid_search(
|
||||||
query=query,
|
collection_names=collection_names,
|
||||||
embedding_function=embedding_function,
|
query=query,
|
||||||
k=k,
|
embedding_function=embedding_function,
|
||||||
reranking_function=reranking_function,
|
k=k,
|
||||||
r=r,
|
reranking_function=reranking_function,
|
||||||
)
|
r=r,
|
||||||
else:
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.debug("Error when using hybrid search, using"
|
||||||
|
" non hybrid search as fallback.")
|
||||||
|
|
||||||
|
if (not hybrid_search) or (context is None):
|
||||||
context = query_collection(
|
context = query_collection(
|
||||||
collection_names=collection_names,
|
collection_names=collection_names,
|
||||||
query=query,
|
query=query,
|
||||||
@ -325,7 +358,6 @@ def get_rag_context(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
context = None
|
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
relevant_contexts.append({**context, "source": file})
|
relevant_contexts.append({**context, "source": file})
|
||||||
|
@ -1030,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
|
|||||||
int(os.environ.get("CHUNK_OVERLAP", "100")),
|
int(os.environ.get("CHUNK_OVERLAP", "100")),
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
|
||||||
|
|
||||||
<context>
|
<context>
|
||||||
[context]
|
[context]
|
||||||
</context>
|
</context>
|
||||||
|
|
||||||
When answer to user:
|
<rules>
|
||||||
- If you don't know, just say that you don't know.
|
- If you don't know, just say so.
|
||||||
- If you don't know when you are not sure, ask for clarification.
|
- If you are not sure, ask for clarification.
|
||||||
Avoid mentioning that you obtained the information from the context.
|
- Answer in the same language as the user query.
|
||||||
And answer according to the language of the user's question.
|
- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
|
||||||
|
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
|
||||||
|
- Answer directly and without using xml tags.
|
||||||
|
</rules>
|
||||||
|
|
||||||
Given the context information, answer the query.
|
<user_query>
|
||||||
Query: [query]"""
|
[query]
|
||||||
|
</user_query>
|
||||||
|
"""
|
||||||
|
|
||||||
RAG_TEMPLATE = PersistentConfig(
|
RAG_TEMPLATE = PersistentConfig(
|
||||||
"RAG_TEMPLATE",
|
"RAG_TEMPLATE",
|
||||||
|
@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
prompt = get_last_user_message(body["messages"])
|
prompt = get_last_user_message(body["messages"])
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
raise Exception("No user message found")
|
raise Exception("No user message found")
|
||||||
|
if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
|
||||||
|
assert context_string.strip(), (
|
||||||
|
"With a 0 relevancy threshold for RAG, the context cannot "
|
||||||
|
"be empty"
|
||||||
|
)
|
||||||
# Workaround for Ollama 2.0+ system prompt issue
|
# Workaround for Ollama 2.0+ system prompt issue
|
||||||
# TODO: replace with add_or_update_system_message
|
# TODO: replace with add_or_update_system_message
|
||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user