from typing import Type

from flask import current_app
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document


class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""
    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    k: int = 3
    conversation_message_task: ConversationMessageTask
    return_resource: str
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        # Fall back to a generic description when the dataset has none,
        # and strip newlines so the description stays on a single line.
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'

        if dataset.indexing_technique == "economy":
            # Economy indexing: query the keyword table instead of a vector store.
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
            return str("\n".join([document.page_content for document in documents]))
        else:
            # High-quality indexing: embed the query and run a vector similarity search.
            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except LLMBadRequestError:
                return ''
            except ProviderTokenNotInitError:
                return ''

            embeddings = CacheEmbedding(embedding_model)

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            if self.k > 0:
                documents = vector_index.search(
                    query,
                    search_type='similarity_score_threshold',
                    search_kwargs={
                        'k': self.k,
                        'filter': {
                            'group_id': [dataset.id]
                        }
                    }
                )
            else:
                documents = []

            hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
            hit_callback.on_tool_end(documents)

            # Keep each hit's similarity score keyed by doc_id so it can be
            # attached to the citation payload below.
            document_score_list = {}
            if dataset.indexing_technique != "economy":
                for item in documents:
                    document_score_list[item.metadata['doc_id']] = item.metadata['score']

            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(
                DocumentSegment.dataset_id == self.dataset_id,
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True,
                DocumentSegment.index_node_id.in_(index_node_ids)
            ).all()

            if segments:
                # Preserve the retrieval ranking: order segments to match the
                # order in which their index nodes were returned.
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(
                    segments,
                    key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float('inf'))
                )

                for segment in sorted_segments:
                    if segment.answer:
                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                    else:
                        document_context_list.append(segment.content)

                if self.return_resource:
                    # Build citation metadata for each segment so the caller
                    # can show where the retrieved context came from.
                    context_list = []
                    resource_number = 1
                    for segment in sorted_segments:
                        document = Document.query.filter(
                            Document.id == segment.document_id,
                            Document.enabled == True,
                            Document.archived == False,
                        ).first()
                        if dataset and document:
                            source = {
                                'position': resource_number,
                                'dataset_id': dataset.id,
                                'dataset_name': dataset.name,
                                'document_id': document.id,
                                'document_name': document.name,
                                'data_source_type': document.data_source_type,
                                'segment_id': segment.id,
                                'retriever_from': self.retriever_from
                            }
                            if dataset.indexing_technique != "economy":
                                source['score'] = document_score_list.get(segment.index_node_id)
                            if self.retriever_from == 'dev':
                                # Extra debugging fields for the dev console.
                                source['hit_count'] = segment.hit_count
                                source['word_count'] = segment.word_count
                                source['segment_position'] = segment.position
                                source['index_node_hash'] = segment.index_node_hash
                            if segment.answer:
                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                            else:
                                source['content'] = segment.content
                            context_list.append(source)
                        resource_number += 1

                    hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
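
# Usage sketch (illustrative only; `dataset` and `conversation_message_task`
# are assumed to be an existing Dataset row and a live ConversationMessageTask,
# and the query string is hypothetical). The keyword arguments mirror the
# fields declared on DatasetRetrieverTool above:
#
#     tool = DatasetRetrieverTool.from_dataset(
#         dataset=dataset,
#         k=3,
#         conversation_message_task=conversation_message_task,
#         return_resource='resource',
#         retriever_from='dev'
#     )
#     context = tool.run('What is our refund policy?')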