From cea107b165b33d926274996de2cb8ebb21aa7547 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Thu, 11 Apr 2024 18:34:17 +0800 Subject: [PATCH] Refactor/react agent (#3355) --- api/core/agent/base_agent_runner.py | 30 +- api/core/agent/cot_agent_runner.py | 619 +++++------------- api/core/agent/cot_chat_agent_runner.py | 71 ++ api/core/agent/cot_completion_agent_runner.py | 69 ++ api/core/agent/entities.py | 17 + api/core/agent/fc_agent_runner.py | 26 +- .../agent/output_parser/cot_output_parser.py | 183 ++++++ api/core/app/apps/agent_chat/app_runner.py | 81 ++- api/core/tools/prompt/template.py | 4 +- 9 files changed, 589 insertions(+), 511 deletions(-) create mode 100644 api/core/agent/cot_chat_agent_runner.py create mode 100644 api/core/agent/cot_completion_agent_runner.py create mode 100644 api/core/agent/output_parser/cot_output_parser.py diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index e9d208330b..dabc13374a 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -238,6 +238,34 @@ class BaseAgentRunner(AppRunner): return prompt_tool + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: + """ + Init tools + """ + tool_instances = {} + prompt_messages_tools = [] + + for tool in self.app_config.agent.tools if self.app_config.agent else []: + try: + prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) + except Exception: + # api tool may be deleted + continue + # save tool entity + tool_instances[tool.tool_name] = tool_entity + # save prompt tool + prompt_messages_tools.append(prompt_tool) + + # convert dataset tools into ModelRuntime Tool format + for dataset_tool in self.dataset_tools: + prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) + # save prompt tool + prompt_messages_tools.append(prompt_tool) + # save tool entity + tool_instances[dataset_tool.identity.name] = dataset_tool + + return tool_instances, prompt_messages_tools + def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool: """ update prompt message tool @@ -325,7 +353,7 @@ class BaseAgentRunner(AppRunner): tool_name: str, tool_input: Union[str, dict], thought: str, - observation: Union[str, str], + observation: Union[str, dict], tool_invoke_meta: Union[str, dict], answer: str, messages_ids: list[str], diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 3b39bb1951..ed55d1b022 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,30 +1,34 @@ import json -import re +from abc import ABC, abstractmethod from collections.abc import Generator -from typing import Literal, Union +from typing import Union from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit +from core.agent.entities import AgentScratchpadUnit +from core.agent.output_parser.cot_output_parser import CotAgentOutputParser from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, - PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, 
UserPromptMessage, ) -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool.tool import Tool from core.tools.tool_engine import ToolEngine from models.model import Message -class CotAgentRunner(BaseAgentRunner): +class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ['wenxin'] + _historic_prompt_messages: list[PromptMessage] = None + _agent_scratchpad: list[AgentScratchpadUnit] = None + _instruction: str = None + _query: str = None + _prompt_messages_tools: list[PromptMessage] = None def run(self, message: Message, query: str, @@ -35,9 +39,7 @@ class CotAgentRunner(BaseAgentRunner): """ app_generate_entity = self.application_generate_entity self._repack_app_generate_entity(app_generate_entity) - - agent_scratchpad: list[AgentScratchpadUnit] = [] - self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) + self._init_react_state(query) # check model mode if 'Observation' not in app_generate_entity.model_config.stop: @@ -46,38 +48,19 @@ class CotAgentRunner(BaseAgentRunner): app_config = self.app_config - # override inputs + # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 - prompt_messages = self.history_prompt_messages - # convert tools into ModelRuntime Tool format - prompt_messages_tools: list[PromptMessageTool] = [] - tool_instances = {} - for tool in app_config.agent.tools if app_config.agent else []: - try: - prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) - except Exception: - # api tool may be deleted - continue - # save tool entity - tool_instances[tool.tool_name] = tool_entity - # save prompt tool - prompt_messages_tools.append(prompt_tool) - - # convert dataset tools into ModelRuntime Tool format - for dataset_tool in self.dataset_tools: - prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) - # save prompt tool - prompt_messages_tools.append(prompt_tool) - # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + tool_instances, self._prompt_messages_tools = self._init_prompt_tools() + prompt_messages = self._organize_prompt_messages() + function_call_state = True llm_usage = { 'usage': None @@ -102,7 +85,7 @@ class CotAgentRunner(BaseAgentRunner): if iteration_step == max_iteration_steps: # the last iteration, remove all tools - prompt_messages_tools = [] + self._prompt_messages_tools = [] message_file_ids = [] @@ -119,18 +102,8 @@ class CotAgentRunner(BaseAgentRunner): agent_thought_id=agent_thought.id ), PublishFrom.APPLICATION_MANAGER) - # update prompt messages - prompt_messages = self._organize_cot_prompt_messages( - mode=app_generate_entity.model_config.mode, - prompt_messages=prompt_messages, - tools=prompt_messages_tools, - agent_scratchpad=agent_scratchpad, - agent_prompt_message=app_config.agent.prompt, - instruction=instruction, - input=query - ) - # recalc llm max tokens + prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( @@ -148,7 +121,7 @@ class 
CotAgentRunner(BaseAgentRunner): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = self._handle_stream_react(chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks) scratchpad = AgentScratchpadUnit( agent_response='', thought='', @@ -164,30 +137,12 @@ class CotAgentRunner(BaseAgentRunner): ), PublishFrom.APPLICATION_MANAGER) for chunk in react_chunks: - if isinstance(chunk, dict): - scratchpad.agent_response += json.dumps(chunk) - try: - if scratchpad.action: - raise Exception("") - scratchpad.action_str = json.dumps(chunk) - scratchpad.action = AgentScratchpadUnit.Action( - action_name=chunk['action'], - action_input=chunk['action_input'] - ) - except: - scratchpad.thought += json.dumps(chunk) - yield LLMResultChunk( - model=self.model_config.model, - prompt_messages=prompt_messages, - system_fingerprint='', - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=json.dumps(chunk, ensure_ascii=False) # if ensure_ascii=True, the text in webui maybe garbled text - ), - usage=None - ) - ) + if isinstance(chunk, AgentScratchpadUnit.Action): + action = chunk + # detect action + scratchpad.agent_response += json.dumps(chunk.dict()) + scratchpad.action_str = json.dumps(chunk.dict()) + scratchpad.action = action else: scratchpad.agent_response += chunk scratchpad.thought += chunk @@ -205,27 +160,29 @@ class CotAgentRunner(BaseAgentRunner): ) scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you' - agent_scratchpad.append(scratchpad) - + self._agent_scratchpad.append(scratchpad) + # get llm usage if 'usage' in usage_dict: increase_usage(llm_usage, usage_dict['usage']) else: usage_dict['usage'] = LLMUsage.empty_usage() - self.save_agent_thought(agent_thought=agent_thought, - tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input - } if scratchpad.action else '', - tool_invoke_meta={}, - thought=scratchpad.thought, - observation='', - answer=scratchpad.agent_response, - messages_ids=[], - llm_usage=usage_dict['usage']) + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=scratchpad.action.action_name if scratchpad.action else '', + tool_input={ + scratchpad.action.action_name: scratchpad.action.action_input + } if scratchpad.action else {}, + tool_invoke_meta={}, + thought=scratchpad.thought, + observation='', + answer=scratchpad.agent_response, + messages_ids=[], + llm_usage=usage_dict['usage'] + ) - if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": + if not scratchpad.is_final(): self.queue_manager.publish(QueueAgentThoughtEvent( agent_thought_id=agent_thought.id ), PublishFrom.APPLICATION_MANAGER) @@ -237,106 +194,43 @@ class CotAgentRunner(BaseAgentRunner): if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: - final_answer = scratchpad.action.action_input if \ - isinstance(scratchpad.action.action_input, str) else \ - json.dumps(scratchpad.action.action_input) + if isinstance(scratchpad.action.action_input, dict): + final_answer = json.dumps(scratchpad.action.action_input) + elif isinstance(scratchpad.action.action_input, str): + final_answer = scratchpad.action.action_input + else: + final_answer = f'{scratchpad.action.action_input}' except json.JSONDecodeError: final_answer = f'{scratchpad.action.action_input}' else: function_call_state = True - # action is tool 
call, invoke tool - tool_call_name = scratchpad.action.action_name - tool_call_args = scratchpad.action.action_input - tool_instance = tool_instances.get(tool_call_name) - if not tool_instance: - answer = f"there is not a tool named {tool_call_name}" - self.save_agent_thought( - agent_thought=agent_thought, - tool_name='', - tool_input='', - tool_invoke_meta=ToolInvokeMeta.error_instance( - f"there is not a tool named {tool_call_name}" - ).to_dict(), - thought=None, - observation={ - tool_call_name: answer - }, - answer=answer, - messages_ids=[] - ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - else: - if isinstance(tool_call_args, str): - try: - tool_call_args = json.loads(tool_call_args) - except json.JSONDecodeError: - pass + tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( + action=scratchpad.action, + tool_instances=tool_instances, + message_file_ids=message_file_ids + ) + scratchpad.observation = tool_invoke_response + scratchpad.agent_response = tool_invoke_response - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback - ) - # publish files - for message_file, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=scratchpad.action.action_name, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, + thought=scratchpad.thought, + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta=tool_invoke_meta.to_dict(), + answer=scratchpad.agent_response, + messages_ids=message_file_ids, + llm_usage=usage_dict['usage'] + ) - # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), PublishFrom.APPLICATION_MANAGER) - # add message file ids - message_file_ids.append(message_file.id) - - # publish files - for message_file, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, - value=message_file.id, - name=save_as) - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), PublishFrom.APPLICATION_MANAGER) - - message_file_ids = [message_file.id for message_file, _ in message_files] - - observation = tool_invoke_response - - # save scratchpad - scratchpad.observation = observation - - # save agent thought - self.save_agent_thought( - agent_thought=agent_thought, - tool_name=tool_call_name, - tool_input={ - tool_call_name: tool_call_args - }, - tool_invoke_meta={ - tool_call_name: tool_invoke_meta.to_dict() - }, - thought=None, - observation={ - tool_call_name: observation - }, - answer=scratchpad.agent_response, - messages_ids=message_file_ids, - ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool message - for prompt_tool in prompt_messages_tools: + for prompt_tool in self._prompt_messages_tools: self.update_prompt_message_tool(tool_instances[prompt_tool.name], 
prompt_tool) iteration_step += 1 @@ -378,96 +272,63 @@ class CotAgentRunner(BaseAgentRunner): system_fingerprint='' )), PublishFrom.APPLICATION_MANAGER) - def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \ - -> Generator[Union[str, dict], None, None]: - def parse_json(json_str): + def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]: + """ + handle invoke action + :param action: action + :param tool_instances: tool instances + :return: observation, meta + """ + # action is tool call, invoke tool + tool_call_name = action.action_name + tool_call_args = action.action_input + tool_instance = tool_instances.get(tool_call_name) + + if not tool_instance: + answer = f"there is not a tool named {tool_call_name}" + return answer, ToolInvokeMeta.error_instance(answer) + + if isinstance(tool_call_args, str): try: - return json.loads(json_str.strip()) - except: - return json_str - - def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: - code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) - if not code_blocks: - return - for block in code_blocks: - json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) - yield parse_json(json_text) - - code_block_cache = '' - code_block_delimiter_count = 0 - in_code_block = False - json_cache = '' - json_quote_count = 0 - in_json = False - got_json = False - - for response in llm_response: - response = response.delta.message.content - if not isinstance(response, str): - continue + tool_call_args = json.loads(tool_call_args) + except json.JSONDecodeError: + pass - # stream - index = 0 - while index < len(response): - steps = 1 - delta = response[index:index+steps] - if delta == '`': - code_block_cache += delta - code_block_delimiter_count += 1 - else: - if not in_code_block: - if code_block_delimiter_count > 0: - yield code_block_cache - code_block_cache = '' - else: - code_block_cache += delta - code_block_delimiter_count = 0 + # invoke tool + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( + tool=tool_instance, + tool_parameters=tool_call_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=self.message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback + ) - if code_block_delimiter_count == 3: - if in_code_block: - yield from extra_json_from_code_block(code_block_cache) - code_block_cache = '' - - in_code_block = not in_code_block - code_block_delimiter_count = 0 + # publish files + for message_file, save_as in message_files: + if save_as: + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) - if not in_code_block: - # handle single json - if delta == '{': - json_quote_count += 1 - in_json = True - json_cache += delta - elif delta == '}': - json_cache += delta - if json_quote_count > 0: - json_quote_count -= 1 - if json_quote_count == 0: - in_json = False - got_json = True - index += steps - continue - else: - if in_json: - json_cache += delta + # publish message file + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) + # add message file ids + message_file_ids.append(message_file.id) - if got_json: - got_json = False - yield parse_json(json_cache) - json_cache = '' - json_quote_count = 0 - in_json = False - - if not in_code_block and 
not in_json: - yield delta.replace('`', '') + return tool_invoke_response, tool_invoke_meta - index += steps - - if code_block_cache: - yield code_block_cache - - if json_cache: - yield parse_json(json_cache) + def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: + """ + convert dict to action + """ + return AgentScratchpadUnit.Action( + action_name=action['action'], + action_input=action['action_input'] + ) def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ @@ -481,15 +342,46 @@ class CotAgentRunner(BaseAgentRunner): return instruction - def _init_agent_scratchpad(self, - agent_scratchpad: list[AgentScratchpadUnit], - messages: list[PromptMessage] - ) -> list[AgentScratchpadUnit]: + def _init_react_state(self, query) -> None: """ init agent scratchpad """ + self._query = query + self._agent_scratchpad = [] + self._historic_prompt_messages = self._organize_historic_prompt_messages() + + @abstractmethod + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + organize prompt messages + """ + + def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: + """ + format assistant message + """ + message = '' + for scratchpad in agent_scratchpad: + if scratchpad.is_final(): + message += f"Final Answer: {scratchpad.agent_response}" + else: + message += f"Thought: {scratchpad.thought}\n\n" + if scratchpad.action_str: + message += f"Action: {scratchpad.action_str}\n\n" + if scratchpad.observation: + message += f"Observation: {scratchpad.observation}\n\n" + + return message + + def _organize_historic_prompt_messages(self) -> list[PromptMessage]: + """ + organize historic prompt messages + """ + result: list[PromptMessage] = [] + scratchpad: list[AgentScratchpadUnit] = [] current_scratchpad: AgentScratchpadUnit = None - for message in messages: + + for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): current_scratchpad = AgentScratchpadUnit( agent_response=message.content, @@ -504,186 +396,29 @@ class CotAgentRunner(BaseAgentRunner): action_name=message.tool_calls[0].function.name, action_input=json.loads(message.tool_calls[0].function.arguments) ) + current_scratchpad.action_str = json.dumps( + current_scratchpad.action.to_dict() + ) except: pass - - agent_scratchpad.append(current_scratchpad) + + scratchpad.append(current_scratchpad) elif isinstance(message, ToolPromptMessage): if current_scratchpad: current_scratchpad.observation = message.content + elif isinstance(message, UserPromptMessage): + result.append(message) + + if scratchpad: + result.append(AssistantPromptMessage( + content=self._format_assistant_message(scratchpad) + )) + + scratchpad = [] + + if scratchpad: + result.append(AssistantPromptMessage( + content=self._format_assistant_message(scratchpad) + )) - return agent_scratchpad - - def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], - agent_prompt_message: AgentPromptEntity, - ): - """ - check chain of thought prompt messages, a standard prompt message is like: - Respond to the human as helpfully and accurately as possible. - - {{instruction}} - - You have access to the following tools: - - {{tools}} - - Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). 
- Valid action values: "Final Answer" or {{tool_names}} - - Provide only ONE action per $JSON_BLOB, as shown: - - ``` - { - "action": $TOOL_NAME, - "action_input": $ACTION_INPUT - } - ``` - """ - - # parse agent prompt message - first_prompt = agent_prompt_message.first_prompt - next_iteration = agent_prompt_message.next_iteration - - if not isinstance(first_prompt, str) or not isinstance(next_iteration, str): - raise ValueError("first_prompt or next_iteration is required in CoT agent mode") - - # check instruction, tools, and tool_names slots - if not first_prompt.find("{{instruction}}") >= 0: - raise ValueError("{{instruction}} is required in first_prompt") - if not first_prompt.find("{{tools}}") >= 0: - raise ValueError("{{tools}} is required in first_prompt") - if not first_prompt.find("{{tool_names}}") >= 0: - raise ValueError("{{tool_names}} is required in first_prompt") - - if mode == "completion": - if not first_prompt.find("{{query}}") >= 0: - raise ValueError("{{query}} is required in first_prompt") - if not first_prompt.find("{{agent_scratchpad}}") >= 0: - raise ValueError("{{agent_scratchpad}} is required in first_prompt") - - if mode == "completion": - if not next_iteration.find("{{observation}}") >= 0: - raise ValueError("{{observation}} is required in next_iteration") - - def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: - """ - convert agent scratchpad list to str - """ - next_iteration = self.app_config.agent.prompt.next_iteration - - result = '' - for scratchpad in agent_scratchpad: - result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \ - next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available') - - return result - - def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - agent_scratchpad: list[AgentScratchpadUnit], - agent_prompt_message: AgentPromptEntity, - instruction: str, - input: str, - ) -> list[PromptMessage]: - """ - organize chain of thought prompt messages, a standard prompt message is like: - Respond to the human as helpfully and accurately as possible. - - {{instruction}} - - You have access to the following tools: - - {{tools}} - - Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). 
- Valid action values: "Final Answer" or {{tool_names}} - - Provide only ONE action per $JSON_BLOB, as shown: - - ``` - {{{{ - "action": $TOOL_NAME, - "action_input": $ACTION_INPUT - }}}} - ``` - """ - - self._check_cot_prompt_messages(mode, agent_prompt_message) - - # parse agent prompt message - first_prompt = agent_prompt_message.first_prompt - - # parse tools - tools_str = self._jsonify_tool_prompt_messages(tools) - - # parse tools name - tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"' - - # get system message - system_message = first_prompt.replace("{{instruction}}", instruction) \ - .replace("{{tools}}", tools_str) \ - .replace("{{tool_names}}", tool_names) - - # organize prompt messages - if mode == "chat": - # override system message - overridden = False - prompt_messages = prompt_messages.copy() - for prompt_message in prompt_messages: - if isinstance(prompt_message, SystemPromptMessage): - prompt_message.content = system_message - overridden = True - break - - # convert tool prompt messages to user prompt messages - for idx, prompt_message in enumerate(prompt_messages): - if isinstance(prompt_message, ToolPromptMessage): - prompt_messages[idx] = UserPromptMessage( - content=prompt_message.content - ) - - if not overridden: - prompt_messages.insert(0, SystemPromptMessage( - content=system_message, - )) - - # add assistant message - if len(agent_scratchpad) > 0 and not self._is_first_iteration: - prompt_messages.append(AssistantPromptMessage( - content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''), - )) - - # add user message - if len(agent_scratchpad) > 0 and not self._is_first_iteration: - prompt_messages.append(UserPromptMessage( - content=(agent_scratchpad[-1].observation or 'It seems that no response is available'), - )) - - self._is_first_iteration = False - - return prompt_messages - elif mode == "completion": - # parse agent scratchpad - agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad) - self._is_first_iteration = False - # parse prompt messages - return [UserPromptMessage( - content=first_prompt.replace("{{instruction}}", instruction) - .replace("{{tools}}", tools_str) - .replace("{{tool_names}}", tool_names) - .replace("{{query}}", input) - .replace("{{agent_scratchpad}}", agent_scratchpad_str), - )] - else: - raise ValueError(f"mode {mode} is not supported") - - def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str: - """ - jsonify tool prompt messages - """ - tools = jsonable_encoder(tools) - try: - return json.dumps(tools, ensure_ascii=False) - except json.JSONDecodeError: - return json.dumps(tools) \ No newline at end of file + return result \ No newline at end of file diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py new file mode 100644 index 0000000000..a904f3e641 --- /dev/null +++ b/api/core/agent/cot_chat_agent_runner.py @@ -0,0 +1,71 @@ +import json + +from core.agent.cot_agent_runner import CotAgentRunner +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.utils.encoders import jsonable_encoder + + +class CotChatAgentRunner(CotAgentRunner): + def _organize_system_prompt(self) -> SystemPromptMessage: + """ + Organize system prompt + """ + prompt_entity = self.app_config.agent.prompt + first_prompt = prompt_entity.first_prompt + + system_prompt = first_prompt \ + .replace("{{instruction}}", 
self._instruction) \ + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ + .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + + return SystemPromptMessage(content=system_prompt) + + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + Organize + """ + # organize system prompt + system_message = self._organize_system_prompt() + + # organize historic prompt messages + historic_messages = self._historic_prompt_messages + + # organize current assistant messages + agent_scratchpad = self._agent_scratchpad + if not agent_scratchpad: + assistant_messages = [] + else: + assistant_message = AssistantPromptMessage(content='') + for unit in agent_scratchpad: + if unit.is_final(): + assistant_message.content += f"Final Answer: {unit.agent_response}" + else: + assistant_message.content += f"Thought: {unit.thought}\n\n" + if unit.action_str: + assistant_message.content += f"Action: {unit.action_str}\n\n" + if unit.observation: + assistant_message.content += f"Observation: {unit.observation}\n\n" + + assistant_messages = [assistant_message] + + # query messages + query_messages = UserPromptMessage(content=self._query) + + if assistant_messages: + messages = [ + system_message, + *historic_messages, + query_messages, + *assistant_messages, + UserPromptMessage(content='continue') + ] + else: + messages = [system_message, *historic_messages, query_messages] + + # join all messages + return messages \ No newline at end of file diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py new file mode 100644 index 0000000000..3f0298d5a3 --- /dev/null +++ b/api/core/agent/cot_completion_agent_runner.py @@ -0,0 +1,69 @@ +import json + +from core.agent.cot_agent_runner import CotAgentRunner +from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.utils.encoders import jsonable_encoder + + +class CotCompletionAgentRunner(CotAgentRunner): + def _organize_instruction_prompt(self) -> str: + """ + Organize instruction prompt + """ + prompt_entity = self.app_config.agent.prompt + first_prompt = prompt_entity.first_prompt + + system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \ + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ + .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + + return system_prompt + + def _organize_historic_prompt(self) -> str: + """ + Organize historic prompt + """ + historic_prompt_messages = self._historic_prompt_messages + historic_prompt = "" + + for message in historic_prompt_messages: + if isinstance(message, UserPromptMessage): + historic_prompt += f"Question: {message.content}\n\n" + elif isinstance(message, AssistantPromptMessage): + historic_prompt += message.content + "\n\n" + + return historic_prompt + + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + Organize prompt messages + """ + # organize system prompt + system_prompt = self._organize_instruction_prompt() + + # organize historic prompt messages + historic_prompt = self._organize_historic_prompt() + + # organize current assistant messages + agent_scratchpad = self._agent_scratchpad + assistant_prompt = '' + for unit in agent_scratchpad: + if unit.is_final(): + assistant_prompt += f"Final Answer: {unit.agent_response}" + else: + assistant_prompt += f"Thought: {unit.thought}\n\n" + if 
unit.action_str: + assistant_prompt += f"Action: {unit.action_str}\n\n" + if unit.observation: + assistant_prompt += f"Observation: {unit.observation}\n\n" + + # query messages + query_prompt = f"Question: {self._query}" + + # join all messages + prompt = system_prompt \ + .replace("{{historic_messages}}", historic_prompt) \ + .replace("{{agent_scratchpad}}", assistant_prompt) \ + .replace("{{query}}", query_prompt) + + return [UserPromptMessage(content=prompt)] \ No newline at end of file diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index e7016d6030..5284faa02e 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -34,12 +34,29 @@ class AgentScratchpadUnit(BaseModel): action_name: str action_input: Union[dict, str] + def to_dict(self) -> dict: + """ + Convert to dictionary. + """ + return { + 'action': self.action_name, + 'action_input': self.action_input, + } + agent_response: Optional[str] = None thought: Optional[str] = None action_str: Optional[str] = None observation: Optional[str] = None action: Optional[Action] = None + def is_final(self) -> bool: + """ + Check if the scratchpad unit is final. + """ + return self.action is None or ( + 'final' in self.action.action_name.lower() and + 'answer' in self.action.action_name.lower() + ) class AgentEntity(BaseModel): """ diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index ea5d31293d..a9b3a80073 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -12,7 +12,6 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageContentType, - PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, ToolPromptMessage, @@ -25,8 +24,8 @@ from models.model import Message logger = logging.getLogger(__name__) class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, message: Message, - query: str, + def run(self, + message: Message, query: str, **kwargs: Any ) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application @@ -41,26 +40,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): prompt_messages = self._organize_user_query(query, prompt_messages) # convert tools into ModelRuntime Tool format - prompt_messages_tools: list[PromptMessageTool] = [] - tool_instances = {} - for tool in app_config.agent.tools if app_config.agent else []: - try: - prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) - except Exception: - # api tool may be deleted - continue - # save tool entity - tool_instances[tool.tool_name] = tool_entity - # save prompt tool - prompt_messages_tools.append(prompt_tool) - - # convert dataset tools into ModelRuntime Tool format - for dataset_tool in self.dataset_tools: - prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) - # save prompt tool - prompt_messages_tools.append(prompt_tool) - # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + tool_instances, prompt_messages_tools = self._init_prompt_tools() iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py new file mode 100644 index 0000000000..91ac41143b --- /dev/null +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -0,0 +1,183 @@ +import json +import re +from collections.abc import Generator +from typing import Union + +from core.agent.entities 
import AgentScratchpadUnit +from core.model_runtime.entities.llm_entities import LLMResultChunk + + +class CotAgentOutputParser: + @classmethod + def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \ + Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + def parse_action(json_str): + try: + action = json.loads(json_str) + action_name = None + action_input = None + + for key, value in action.items(): + if 'input' in key.lower(): + action_input = value + else: + action_name = value + + if action_name is not None and action_input is not None: + return AgentScratchpadUnit.Action( + action_name=action_name, + action_input=action_input, + ) + else: + return json_str or '' + except: + return json_str or '' + + def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) + if not code_blocks: + return + for block in code_blocks: + json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) + yield parse_action(json_text) + + code_block_cache = '' + code_block_delimiter_count = 0 + in_code_block = False + json_cache = '' + json_quote_count = 0 + in_json = False + got_json = False + + action_cache = '' + action_str = 'action:' + action_idx = 0 + + thought_cache = '' + thought_str = 'thought:' + thought_idx = 0 + + for response in llm_response: + response = response.delta.message.content + if not isinstance(response, str): + continue + + # stream + index = 0 + while index < len(response): + steps = 1 + delta = response[index:index+steps] + last_character = response[index-1] if index > 0 else '' + + if delta == '`': + code_block_cache += delta + code_block_delimiter_count += 1 + else: + if not in_code_block: + if code_block_delimiter_count > 0: + yield code_block_cache + code_block_cache = '' + else: + code_block_cache += delta + code_block_delimiter_count = 0 + + if not in_code_block and not in_json: + if delta.lower() == action_str[action_idx] and action_idx == 0: + if last_character not in ['\n', ' ', '']: + index += steps + yield delta + continue + + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = '' + action_idx = 0 + index += steps + continue + elif delta.lower() == action_str[action_idx] and action_idx > 0: + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = '' + action_idx = 0 + index += steps + continue + else: + if action_cache: + yield action_cache + action_cache = '' + action_idx = 0 + + if delta.lower() == thought_str[thought_idx] and thought_idx == 0: + if last_character not in ['\n', ' ', '']: + index += steps + yield delta + continue + + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = '' + thought_idx = 0 + index += steps + continue + elif delta.lower() == thought_str[thought_idx] and thought_idx > 0: + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = '' + thought_idx = 0 + index += steps + continue + else: + if thought_cache: + yield thought_cache + thought_cache = '' + thought_idx = 0 + + if code_block_delimiter_count == 3: + if in_code_block: + yield from extra_json_from_code_block(code_block_cache) + code_block_cache = '' + + in_code_block = not in_code_block + code_block_delimiter_count = 0 + + if not in_code_block: + # handle single json + if delta == '{': + json_quote_count += 1 + in_json = True + json_cache += delta + elif delta == '}': + 
json_cache += delta + if json_quote_count > 0: + json_quote_count -= 1 + if json_quote_count == 0: + in_json = False + got_json = True + index += steps + continue + else: + if in_json: + json_cache += delta + + if got_json: + got_json = False + yield parse_action(json_cache) + json_cache = '' + json_quote_count = 0 + in_json = False + + if not in_code_block and not in_json: + yield delta.replace('`', '') + + index += steps + + if code_block_cache: + yield code_block_cache + + if json_cache: + yield parse_action(json_cache) + diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index f42b146e51..dfa5d4591b 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,7 +1,8 @@ import logging from typing import cast -from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from core.agent.entities import AgentEntity from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig @@ -11,8 +12,8 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Mo from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.model_entities import ModelFeature +from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationException from core.tools.entities.tool_entities import ToolRuntimeVariablePool @@ -207,48 +208,40 @@ class AgentChatAppRunner(AppRunner): # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - assistant_cot_runner = CotAgentRunner( - tenant_id=app_config.tenant_id, - application_generate_entity=application_generate_entity, - conversation=conversation, - app_config=app_config, - model_config=application_generate_entity.model_config, - config=agent_entity, - queue_manager=queue_manager, - message=message, - user_id=application_generate_entity.user_id, - memory=memory, - prompt_messages=prompt_message, - variables_pool=tool_variables, - db_variables=tool_conversation_variables, - model_instance=model_instance - ) - invoke_result = assistant_cot_runner.run( - message=message, - query=query, - inputs=inputs, - ) + # check LLM mode + if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + runner_cls = CotChatAgentRunner + elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value: + runner_cls = CotCompletionAgentRunner + else: + raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}") elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - assistant_fc_runner = FunctionCallAgentRunner( - tenant_id=app_config.tenant_id, - application_generate_entity=application_generate_entity, - conversation=conversation, - app_config=app_config, - model_config=application_generate_entity.model_config, - config=agent_entity, - queue_manager=queue_manager, - message=message, - 
user_id=application_generate_entity.user_id, - memory=memory, - prompt_messages=prompt_message, - variables_pool=tool_variables, - db_variables=tool_conversation_variables, - model_instance=model_instance - ) - invoke_result = assistant_fc_runner.run( - message=message, - query=query, - ) + runner_cls = FunctionCallAgentRunner + else: + raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}") + + runner = runner_cls( + tenant_id=app_config.tenant_id, + application_generate_entity=application_generate_entity, + conversation=conversation, + app_config=app_config, + model_config=application_generate_entity.model_config, + config=agent_entity, + queue_manager=queue_manager, + message=message, + user_id=application_generate_entity.user_id, + memory=memory, + prompt_messages=prompt_message, + variables_pool=tool_variables, + db_variables=tool_conversation_variables, + model_instance=model_instance + ) + + invoke_result = runner.run( + message=message, + query=query, + inputs=inputs, + ) # handle invoke result self._handle_invoke_result( diff --git a/api/core/tools/prompt/template.py b/api/core/tools/prompt/template.py index 3d35592279..b0cf1a77fb 100644 --- a/api/core/tools/prompt/template.py +++ b/api/core/tools/prompt/template.py @@ -38,8 +38,10 @@ Action: ``` Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +{{historic_messages}} Question: {{query}} -Thought: {{agent_scratchpad}}""" +{{agent_scratchpad}} +Thought:""" ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} Thought:"""
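
Notes on usage (illustrative, not part of the patch):

The refactor funnels all ReAct streaming through the new `CotAgentOutputParser.handle_react_stream_output`, which yields plain text for thought content and `AgentScratchpadUnit.Action` objects for parsed tool calls. The sketch below is a minimal, hypothetical consumer of that classmethod, assuming the `api` package from this repository is importable; the model name and streamed text are invented for illustration, and the chunk construction mirrors the `LLMResultChunk` shape the runner itself emits elsewhere in this patch.

```python
from collections.abc import Generator

from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage


def fake_llm_stream(pieces: list[str]) -> Generator[LLMResultChunk, None, None]:
    # Wrap raw text pieces the way the runner receives them from invoke_llm().
    for piece in pieces:
        yield LLMResultChunk(
            model='example-model',  # illustrative model name
            prompt_messages=[],
            system_fingerprint='',
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=piece),
                usage=None,
            ),
        )


pieces = [
    'Thought: I should look this up.\n',
    'Action:\n```json\n{"action": "current_time", "action_input": {}}\n```\n',
]

text_parts: list[str] = []
actions: list[AgentScratchpadUnit.Action] = []

for output in CotAgentOutputParser.handle_react_stream_output(fake_llm_stream(pieces)):
    if isinstance(output, AgentScratchpadUnit.Action):
        # Parsed tool invocation; the runner stores this on the scratchpad
        # and later dispatches it through _handle_invoke_action.
        actions.append(output)
    else:
        # Plain streamed text; the runner accumulates this as the thought.
        text_parts.append(output)

print('thought text:', ''.join(text_parts))
print('parsed actions:', [a.to_dict() for a in actions])
```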
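`AgentScratchpadUnit` also gains `Action.to_dict()` and `is_final()`, which the runners use to serialize tool calls and to detect the terminating "Final Answer" step. A small sketch exercising both helpers follows; the field values are made up for illustration.

```python
import json

from core.agent.entities import AgentScratchpadUnit

# A tool-call step: an action is present and it is not a final answer.
tool_step = AgentScratchpadUnit(
    agent_response='',
    thought='I need the current time.',
    action_str=json.dumps({'action': 'current_time', 'action_input': {}}),
    observation='2024-04-11 18:34:17',  # hypothetical tool output
    action=AgentScratchpadUnit.Action(action_name='current_time', action_input={}),
)

# A terminating step: the "Final Answer" pseudo-action ends the ReAct loop.
final_step = AgentScratchpadUnit(
    agent_response='It is 18:34.',
    thought='I can answer now.',
    action=AgentScratchpadUnit.Action(action_name='Final Answer', action_input='It is 18:34.'),
)

assert not tool_step.is_final()
assert final_step.is_final()
print(tool_step.action.to_dict())  # {'action': 'current_time', 'action_input': {}}
```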