diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index b382016473..8d6e821f4c 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -3,11 +3,13 @@ from typing import Any from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset +from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -54,7 +56,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if not dataset: return "" - for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) if dataset.provider == "external": @@ -125,7 +126,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): ) else: documents = [] - for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} @@ -134,50 +134,46 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(documents) + if records: + for record in records: + segment = record.segment if segment.answer: document_context_list.append( - f"question:{segment.get_sign_content()} answer:{segment.answer}" + DocumentContext( + content=f"question:{segment.get_sign_content()} answer:{segment.answer}", + score=record.score, + ) ) else: - document_context_list.append(segment.get_sign_content()) + document_context_list.append( + DocumentContext( + content=segment.get_sign_content(), + score=record.score, + ) + ) + retrieval_resource_list = [] if self.return_resource: - context_list = [] - resource_number = 1 - for segment in sorted_segments: - document_segment = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, + for record in records: + segment = record.segment + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ).first() - if not document_segment: - continue - if dataset and document_segment: + if dataset and document: source = { - "position": resource_number, "dataset_id": dataset.id, "dataset_name": dataset.name, - "document_id": document_segment.id, - "document_name": document_segment.name, - "data_source_type": document_segment.data_source_type, + "document_id": document.id, # type: ignore + "document_name": document.name, # type: ignore + "data_source_type": document.data_source_type, # type: ignore "segment_id": segment.id, "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), + "score": record.score or 0.0, } + if self.retriever_from == "dev": source["hit_count"] = segment.hit_count source["word_count"] = segment.word_count @@ -187,10 +183,19 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: source["content"] = segment.content - context_list.append(source) - resource_number += 1 + retrieval_resource_list.append(source) - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) + if self.return_resource and retrieval_resource_list: + retrieval_resource_list = sorted( + retrieval_resource_list, + key=lambda x: x.get("score") or 0.0, + reverse=True, + ) + for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore + item["position"] = position # type: ignore + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(retrieval_resource_list) + if document_context_list: + document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) + return str("\n".join([document_context.content for document_context in document_context_list])) + return ""