def to_dataset_retriever_tool(self, tenant_id: str,
                              dataset_ids: list[str],
                              retrieve_config: DatasetRetrieveConfigEntity,
                              return_resource: bool,
                              invoke_from: InvokeFrom,
                              hit_callback: DatasetIndexToolCallbackHandler) \
        -> Optional[list[BaseTool]]:
    """
    Build the dataset retriever tool(s) for the given datasets.

    A dataset tool is a tool that can be used to retrieve information from a dataset.
    Each dataset id is looked up for the given tenant; ids that do not resolve to a
    dataset, or whose dataset has no available documents, are ignored.

    - SINGLE strategy: one DatasetRetrieverTool per available dataset, configured
      from the dataset's own retrieval model (or a built-in default when it has none).
    - MULTIPLE strategy: exactly one DatasetMultiRetrieverTool covering all available
      dataset ids, configured from ``retrieve_config``.

    :param tenant_id: tenant id
    :param dataset_ids: dataset ids
    :param retrieve_config: retrieve config
    :param return_resource: return resource
    :param invoke_from: invoke from
    :param hit_callback: hit callback
    :return: list of tools; empty when the strategy is SINGLE and no dataset is
        available, or when the strategy is neither SINGLE nor MULTIPLE
    """
    tools = []
    available_datasets = []
    for dataset_id in dataset_ids:
        # get dataset from dataset id, scoped to the tenant
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()

        # skip ids that do not resolve to a dataset of this tenant
        if not dataset:
            continue

        # skip datasets that have no available documents
        # NOTE(review): the previous code compared available_document_count
        # to 0 twice; if a second attribute was meant to be checked, confirm
        # and add it here.
        if dataset.available_document_count == 0:
            continue

        available_datasets.append(dataset)

    if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
        # fallback used when a dataset carries no retrieval model of its own
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }

        for dataset in available_datasets:
            retrieval_model_config = dataset.retrieval_model \
                if dataset.retrieval_model else default_retrieval_model

            # get top k
            top_k = retrieval_model_config['top_k']

            # the score threshold is only forwarded when explicitly enabled
            score_threshold = None
            score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
            if score_threshold_enabled:
                score_threshold = retrieval_model_config.get("score_threshold")

            tool = DatasetRetrieverTool.from_dataset(
                dataset=dataset,
                top_k=top_k,
                score_threshold=score_threshold,
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source()
            )

            tools.append(tool)
    elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
        # NOTE(review): assumes retrieve_config.reranking_model is a dict for
        # the MULTIPLE strategy; a None value would raise AttributeError on
        # .get() - confirm against the entity definition.
        tool = DatasetMultiRetrieverTool.from_dataset(
            dataset_ids=[dataset.id for dataset in available_datasets],
            tenant_id=tenant_id,
            top_k=retrieve_config.top_k or 2,
            score_threshold=retrieve_config.score_threshold,
            hit_callbacks=[hit_callback],
            return_resource=return_resource,
            retriever_from=invoke_from.to_source(),
            reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
            reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
        )

        tools.append(tool)

    return tools