Merge pull request #5378 from thiswillbeyourgithub/fix_RAG_and_web

fix: RAG and Web Search + RAG enhancements
This commit is contained in:
Timothy Jaeryang Baek 2024-09-13 05:38:53 +01:00 committed by GitHub
commit 7dc4cb30b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 72 additions and 29 deletions

View File

@ -1,5 +1,6 @@
import logging import logging
import os import os
import uuid
from typing import Optional, Union from typing import Optional, Union
import requests import requests
@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
k: int, k: int,
reranking_function, reranking_function,
r: float, r: float,
): ) -> dict:
try: try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name) result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
raise e raise e
def merge_and_sort_query_results(query_results, k, reverse=False): def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
# Initialize lists to store combined data # Initialize lists to store combined data
combined_distances = [] combined_distances = []
combined_documents = [] combined_documents = []
@ -180,7 +181,7 @@ def query_collection(
query: str, query: str,
embedding_function, embedding_function,
k: int, k: int,
): ) -> dict:
results = [] results = []
for collection_name in collection_names: for collection_name in collection_names:
if collection_name: if collection_name:
@ -192,8 +193,8 @@ def query_collection(
embedding_function=embedding_function, embedding_function=embedding_function,
) )
results.append(result) results.append(result)
except Exception: except Exception as e:
pass log.exception(f"Error when querying the collection: {e}")
else: else:
pass pass
@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
k: int, k: int,
reranking_function, reranking_function,
r: float, r: float,
): ) -> dict:
results = [] results = []
failed = 0
for collection_name in collection_names: for collection_name in collection_names:
try: try:
result = query_doc_with_hybrid_search( result = query_doc_with_hybrid_search(
@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
r=r, r=r,
) )
results.append(result) results.append(result)
except Exception: except Exception as e:
pass log.exception(
"Error when querying the collection with "
f"hybrid_search: {e}"
)
failed += 1
if failed == len(collection_names):
raise Exception("Hybrid search failed for all collections. Using "
"Non hybrid search as fallback.")
return merge_and_sort_query_results(results, k=k, reverse=True) return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
    """Fill a RAG prompt template with the retrieved context and user query.

    The template must contain the literal marker ``[context]`` exactly once
    and may contain ``[query]``. Both markers are substituted and the
    resulting prompt is returned.

    Args:
        template: Prompt template holding the ``[context]`` / ``[query]``
            markers.
        context: Retrieved text substituted for ``[context]``.
        query: User query substituted for ``[query]``.

    Returns:
        The template with both markers replaced.

    Raises:
        AssertionError: If ``template`` does not contain ``[context]``
            exactly once.
    """
    count = template.count("[context]")
    # A single check covers both "missing" and "duplicated" markers.
    assert count == 1, (
        f"RAG template contains an unexpected number of '[context]' : {count}"
    )

    if "<context>" in context and "</context>" in context:
        log.debug(
            "WARNING: Potential prompt injection attack: the RAG "
            "context contains '<context>' and '</context>'. This might be "
            "nothing, or the user might be trying to hack something."
        )

    # Swap the template's own "[query]" marker for a unique placeholder
    # *before* inserting the context. Otherwise a literal "[query]" inside
    # the retrieved context would also be replaced by the user query. The
    # marker is lowercase "[query]"; replacing "[QUERY]" left the template's
    # marker unfilled whenever the context contained "[query]".
    query_placeholder = str(uuid.uuid4())
    template = template.replace("[query]", query_placeholder)
    template = template.replace("[context]", context)
    template = template.replace(query_placeholder, query)
    return template
@ -304,19 +331,25 @@ def get_rag_context(
continue continue
try: try:
context = None
if file["type"] == "text": if file["type"] == "text":
context = file["content"] context = file["content"]
else: else:
if hybrid_search: if hybrid_search:
context = query_collection_with_hybrid_search( try:
collection_names=collection_names, context = query_collection_with_hybrid_search(
query=query, collection_names=collection_names,
embedding_function=embedding_function, query=query,
k=k, embedding_function=embedding_function,
reranking_function=reranking_function, k=k,
r=r, reranking_function=reranking_function,
) r=r,
else: )
except Exception as e:
log.debug("Error when using hybrid search, using"
" non hybrid search as fallback.")
if (not hybrid_search) or (context is None):
context = query_collection( context = query_collection(
collection_names=collection_names, collection_names=collection_names,
query=query, query=query,
@ -325,7 +358,6 @@ def get_rag_context(
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
context = None
if context: if context:
relevant_contexts.append({**context, "source": file}) relevant_contexts.append({**context, "source": file})

View File

@ -1030,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
int(os.environ.get("CHUNK_OVERLAP", "100")), int(os.environ.get("CHUNK_OVERLAP", "100")),
) )
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags. DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
<context> <context>
[context] [context]
</context> </context>
When answer to user: <rules>
- If you don't know, just say that you don't know. - If you don't know, just say so.
- If you don't know when you are not sure, ask for clarification. - If you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context. - Answer in the same language as the user query.
And answer according to the language of the user's question. - If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
- Answer directly and without using xml tags.
</rules>
Given the context information, answer the query. <user_query>
Query: [query]""" [query]
</user_query>
"""
RAG_TEMPLATE = PersistentConfig( RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE", "RAG_TEMPLATE",

View File

@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
prompt = get_last_user_message(body["messages"]) prompt = get_last_user_message(body["messages"])
if prompt is None: if prompt is None:
raise Exception("No user message found") raise Exception("No user message found")
if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
assert context_string.strip(), (
"With a 0 relevancy threshold for RAG, the context cannot "
"be empty"
)
# Workaround for Ollama 2.0+ system prompt issue # Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message # TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":