diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 0f9c753056..dd74406f30 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -33,6 +33,7 @@ class RetrievalService: return [] all_documents = [] threads = [] + exceptions = [] # retrieval_model source with keyword if retrival_method == 'keyword_search': keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ @@ -40,7 +41,8 @@ class RetrievalService: 'dataset_id': dataset_id, 'query': query, 'top_k': top_k, - 'all_documents': all_documents + 'all_documents': all_documents, + 'exceptions': exceptions, }) threads.append(keyword_thread) keyword_thread.start() @@ -54,7 +56,8 @@ class RetrievalService: 'score_threshold': score_threshold, 'reranking_model': reranking_model, 'all_documents': all_documents, - 'retrival_method': retrival_method + 'retrival_method': retrival_method, + 'exceptions': exceptions, }) threads.append(embedding_thread) embedding_thread.start() @@ -69,7 +72,8 @@ class RetrievalService: 'score_threshold': score_threshold, 'top_k': top_k, 'reranking_model': reranking_model, - 'all_documents': all_documents + 'all_documents': all_documents, + 'exceptions': exceptions, }) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -77,6 +81,10 @@ class RetrievalService: for thread in threads: thread.join() + if exceptions: + exception_message = ';\n'.join(exceptions) + raise Exception(exception_message) + if retrival_method == 'hybrid_search': data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents = data_post_processor.invoke( @@ -89,82 +97,91 @@ class RetrievalService: @classmethod def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, all_documents: list): + top_k: int, all_documents: list, exceptions: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + try: + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() - keyword = Keyword( - dataset=dataset - ) + keyword = Keyword( + dataset=dataset + ) - documents = keyword.search( - query, - top_k=top_k - ) - all_documents.extend(documents) + documents = keyword.search( + query, + top_k=top_k + ) + all_documents.extend(documents) + except Exception as e: + exceptions.append(str(e)) @classmethod def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str): + all_documents: list, retrival_method: str, exceptions: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + try: + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() - vector = Vector( - dataset=dataset - ) + vector = Vector( + dataset=dataset + ) - documents = vector.search_by_vector( - query, - search_type='similarity_score_threshold', - top_k=top_k, - score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } - ) + documents = vector.search_by_vector( + query, + search_type='similarity_score_threshold', + top_k=top_k, + score_threshold=score_threshold, + filter={ + 'group_id': [dataset.id] + } + ) - if documents: - if reranking_model and retrival_method == 'semantic_search': - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) + if documents: + if reranking_model and retrival_method == 'semantic_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) + else: + all_documents.extend(documents) + except Exception as e: + exceptions.append(str(e)) @classmethod def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str): + all_documents: list, retrival_method: str, exceptions: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + try: + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() - vector_processor = Vector( - dataset=dataset, - ) + vector_processor = Vector( + dataset=dataset, + ) - documents = vector_processor.search_by_full_text( - query, - top_k=top_k - ) - if documents: - if reranking_model and retrival_method == 'full_text_search': - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) + documents = vector_processor.search_by_full_text( + query, + top_k=top_k + ) + if documents: + if reranking_model and retrival_method == 'full_text_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) + else: + all_documents.extend(documents) + except Exception as e: + exceptions.append(str(e))