fix: escape double quotation marks in the vector DB search query (#6506)

This commit is contained in:
Sangmin Ahn 2024-07-23 16:02:25 +09:00 committed by GitHub
parent 5fcc2caeed
commit 093b8ca475
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -110,7 +110,7 @@ class RetrievalService:
) )
documents = keyword.search( documents = keyword.search(
query, cls.escape_query_for_search(query),
top_k=top_k top_k=top_k
) )
all_documents.extend(documents) all_documents.extend(documents)
@ -132,7 +132,7 @@ class RetrievalService:
) )
documents = vector.search_by_vector( documents = vector.search_by_vector(
query, cls.escape_query_for_search(query),
search_type='similarity_score_threshold', search_type='similarity_score_threshold',
top_k=top_k, top_k=top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
@ -170,7 +170,7 @@ class RetrievalService:
) )
documents = vector_processor.search_by_full_text( documents = vector_processor.search_by_full_text(
query, cls.escape_query_for_search(query),
top_k=top_k top_k=top_k
) )
if documents: if documents:
@ -186,3 +186,7 @@ class RetrievalService:
all_documents.extend(documents) all_documents.extend(documents)
except Exception as e: except Exception as e:
exceptions.append(str(e)) exceptions.append(str(e))
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')

View File

@ -40,7 +40,7 @@ class HitTestingService:
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=cls.escape_query_for_search(query),
top_k=retrieval_model['top_k'], top_k=retrieval_model['top_k'],
score_threshold=retrieval_model['score_threshold'] score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None, if retrieval_model['score_threshold_enabled'] else None,
@ -104,3 +104,7 @@ class HitTestingService:
if not query or len(query) > 250: if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters') raise ValueError('Query is required and cannot exceed 250 characters')
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')