From edffa5666d66d9b9c43734ef8d8dfa401b7cdc01 Mon Sep 17 00:00:00 2001 From: sino Date: Mon, 17 Jun 2024 21:20:17 +0800 Subject: [PATCH] fix: got unknown type of prompt message in multi-round ReAct agent chat (#5245) --- api/core/agent/cot_agent_runner.py | 81 +++++++++++++++++------------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 20d924ddc7..54aa0c9906 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -32,9 +32,9 @@ class CotAgentRunner(BaseAgentRunner, ABC): _prompt_messages_tools: list[PromptMessage] = None def run(self, message: Message, - query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + query: str, + inputs: dict[str, str], + ) -> Union[Generator, LLMResult]: """ Run Cot agent application """ @@ -52,7 +52,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools( + instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 @@ -61,7 +62,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_instances, self._prompt_messages_tools = self._init_prompt_tools() prompt_messages = self._organize_prompt_messages() - + function_call_state = True llm_usage = { 'usage': None @@ -120,9 +121,10 @@ class CotAgentRunner(BaseAgentRunner, ABC): # check llm result if not chunks: raise ValueError("failed to invoke llm") - + usage_dict = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output( + chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response='', thought='', @@ -160,15 +162,16 @@ class CotAgentRunner(BaseAgentRunner, ABC): ) ) - scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you' + scratchpad.thought = scratchpad.thought.strip( + ) or 'I am thinking about how to help you' self._agent_scratchpad.append(scratchpad) - + # get llm usage if 'usage' in usage_dict: increase_usage(llm_usage, usage_dict['usage']) else: usage_dict['usage'] = LLMUsage.empty_usage() - + self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name if scratchpad.action else '', @@ -182,7 +185,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): messages_ids=[], llm_usage=usage_dict['usage'] ) - + if not scratchpad.is_final(): self.queue_manager.publish(QueueAgentThoughtEvent( agent_thought_id=agent_thought.id @@ -196,7 +199,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): # action is final answer, return final answer directly try: if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps(scratchpad.action.action_input) + final_answer = json.dumps( + scratchpad.action.action_input) elif isinstance(scratchpad.action.action_input, str): final_answer = scratchpad.action.action_input else: @@ -207,7 +211,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): function_call_state = True # action is tool call, invoke tool tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( - action=scratchpad.action, + action=scratchpad.action, tool_instances=tool_instances, message_file_ids=message_file_ids ) @@ -217,10 +221,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name, - tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, + tool_input={ + scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought, - observation={scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, + observation={ + scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta={ + scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, messages_ids=message_file_ids, llm_usage=usage_dict['usage'] @@ -232,7 +239,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): # update prompt tool message for prompt_tool in self._prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + self.update_prompt_message_tool( + tool_instances[prompt_tool.name], prompt_tool) iteration_step += 1 @@ -251,12 +259,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name='', tool_input={}, tool_invoke_meta={}, thought=final_answer, - observation={}, + observation={}, answer=final_answer, messages_ids=[] ) @@ -269,11 +277,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): message=AssistantPromptMessage( content=final_answer ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), + usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage( + ), system_fingerprint='' )), PublishFrom.APPLICATION_MANAGER) - def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, + def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, tool_instances: dict[str, Tool], message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]: """ @@ -290,7 +299,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): if not tool_instance: answer = f"there is not a tool named {tool_call_name}" return answer, ToolInvokeMeta.error_instance(answer) - + if isinstance(tool_call_args, str): try: tool_call_args = json.loads(tool_call_args) @@ -311,7 +320,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): # publish files for message_file, save_as in message_files: if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) + self.variables_pool.set_file( + tool_name=tool_call_name, value=message_file.id, name=save_as) # publish message file self.queue_manager.publish(QueueMessageFileEvent( @@ -342,7 +352,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): continue return instruction - + def _init_react_state(self, query) -> None: """ init agent scratchpad @@ -350,7 +360,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): self._query = query self._agent_scratchpad = [] self._historic_prompt_messages = self._organize_historic_prompt_messages() - + @abstractmethod def _organize_prompt_messages(self) -> list[PromptMessage]: """ @@ -382,13 +392,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): scratchpads: list[AgentScratchpadUnit] = [] current_scratchpad: AgentScratchpadUnit = None - self.history_prompt_messages = AgentHistoryPromptTransform( - model_config=self.model_config, - prompt_messages=current_session_messages or [], - history_messages=self.history_prompt_messages, - memory=self.memory - ).get_prompt() - for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): if not current_scratchpad: @@ -404,7 +407,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): try: current_scratchpad.action = AgentScratchpadUnit.Action( action_name=message.tool_calls[0].function.name, - action_input=json.loads(message.tool_calls[0].function.arguments) + action_input=json.loads( + message.tool_calls[0].function.arguments) ) current_scratchpad.action_str = json.dumps( current_scratchpad.action.to_dict() @@ -424,10 +428,15 @@ class CotAgentRunner(BaseAgentRunner, ABC): result.append(message) - if scratchpads: result.append(AssistantPromptMessage( content=self._format_assistant_message(scratchpads) )) - - return result \ No newline at end of file + + historic_prompts = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=current_session_messages or [], + history_messages=result, + memory=self.memory + ).get_prompt() + return historic_prompts