From 093b8ca475f99623f0ede726581935bdc08fcd54 Mon Sep 17 00:00:00 2001
From: Sangmin Ahn
Date: Tue, 23 Jul 2024 16:02:25 +0900
Subject: [PATCH] fix: escape double quotation marks in the vector DB search query (#6506)

---
 api/core/rag/datasource/retrieval_service.py | 10 +++++++---
 api/services/hit_testing_service.py          |  6 +++++-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 8814c61433..702dcec314 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -110,7 +110,7 @@ class RetrievalService:
                 )
 
                 documents = keyword.search(
-                    query,
+                    cls.escape_query_for_search(query),
                     top_k=top_k
                 )
                 all_documents.extend(documents)
@@ -132,7 +132,7 @@ class RetrievalService:
                 )
 
                 documents = vector.search_by_vector(
-                    query,
+                    cls.escape_query_for_search(query),
                     search_type='similarity_score_threshold',
                     top_k=top_k,
                     score_threshold=score_threshold,
@@ -170,7 +170,7 @@ class RetrievalService:
                 )
 
                 documents = vector_processor.search_by_full_text(
-                    query,
+                    cls.escape_query_for_search(query),
                     top_k=top_k
                 )
                 if documents:
@@ -186,3 +186,7 @@ class RetrievalService:
                         all_documents.extend(documents)
             except Exception as e:
                 exceptions.append(str(e))
+
+    @staticmethod
+    def escape_query_for_search(query: str) -> str:
+        return query.replace('"', '\\"')
\ No newline at end of file
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index b83e1d8cb7..a04ffdfbbe 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -40,7 +40,7 @@ class HitTestingService:
 
         all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                   dataset_id=dataset.id,
-                                                  query=query,
+                                                  query=cls.escape_query_for_search(query),
                                                   top_k=retrieval_model['top_k'],
                                                   score_threshold=retrieval_model['score_threshold']
                                                   if retrieval_model['score_threshold_enabled'] else None,
@@ -104,3 +104,7 @@ class HitTestingService:
 
         if not query or len(query) > 250:
             raise ValueError('Query is required and cannot exceed 250 characters')
+
+    @staticmethod
+    def escape_query_for_search(query: str) -> str:
+        return query.replace('"', '\\"')