diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index 958183f089..ebb9531b26 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -52,7 +52,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): elif len(self.tools) == 1: tool = next(iter(self.tools)) tool = cast(DatasetRetrieverTool, tool) - rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']}) + rst = tool.run(tool_input={'query': kwargs['input']}) return AgentFinish(return_values={"output": rst}, log=rst) if intermediate_steps: diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index 359078607f..addf0831ce 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -45,7 +45,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio :return: """ original_max_tokens = self.llm.max_tokens - self.llm.max_tokens = 15 + self.llm.max_tokens = 40 prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) messages = prompt.to_messages() diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py index 67522418d3..8d682b59d5 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py @@ -90,7 +90,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): elif len(self.dataset_tools) == 1: tool = next(iter(self.dataset_tools)) tool = cast(DatasetRetrieverTool, tool) - rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']}) + rst = tool.run(tool_input={'query': kwargs['input']}) return AgentFinish(return_values={"output": rst}, log=rst) full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) diff --git a/api/core/callback_handler/dataset_tool_callback_handler.py b/api/core/callback_handler/dataset_tool_callback_handler.py index 2d863487d4..7d2ba4de1f 100644 --- a/api/core/callback_handler/dataset_tool_callback_handler.py +++ b/api/core/callback_handler/dataset_tool_callback_handler.py @@ -1,5 +1,6 @@ import json import logging +from json import JSONDecodeError from typing import Any, Dict, List, Union, Optional @@ -44,10 +45,15 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): input_str: str, **kwargs: Any, ) -> None: - # tool_name = serialized.get('name') - input_dict = json.loads(input_str.replace("'", "\"")) - dataset_id = input_dict.get('dataset_id') - query = input_dict.get('query') + tool_name: str = serialized.get('name') + dataset_id = tool_name.removeprefix('dataset-') + + try: + input_dict = json.loads(input_str.replace("'", "\"")) + query = input_dict.get('query') + except JSONDecodeError: + query = input_str + self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query)) def on_tool_end( diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index a0fe89fe83..ea63fd6be9 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -1,4 +1,3 @@ -import re from typing import Type from flask import current_app @@ -16,7 +15,6 @@ from models.dataset import Dataset, DocumentSegment class DatasetRetrieverToolInput(BaseModel): - dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.") query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") @@ -37,27 +35,22 @@ class DatasetRetrieverTool(BaseTool): description = 'useful for when you want to answer queries about the ' + dataset.name description = description.replace('\n', '').replace('\r', '') - description += '\nID of dataset MUST be ' + dataset.id return cls( + name=f'dataset-{dataset.id}', tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, **kwargs ) - def _run(self, dataset_id: str, query: str) -> str: - pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b' - match = re.search(pattern, dataset_id, re.IGNORECASE) - if match: - dataset_id = match.group() - + def _run(self, query: str) -> str: dataset = db.session.query(Dataset).filter( Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id + Dataset.id == self.dataset_id ).first() if not dataset: - return f'[{self.name} failed to find dataset with id {dataset_id}.]' + return f'[{self.name} failed to find dataset with id {self.dataset_id}.]' if dataset.indexing_technique == "economy": # use keyword table query