fix: got unknown type of prompt message in multi-round ReAct agent chat (#5245)

This commit is contained in:
sino 2024-06-17 21:20:17 +08:00 committed by GitHub
parent 54756cd3b2
commit edffa5666d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,9 +32,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
_prompt_messages_tools: list[PromptMessage] = None _prompt_messages_tools: list[PromptMessage] = None
def run(self, message: Message, def run(self, message: Message,
query: str, query: str,
inputs: dict[str, str], inputs: dict[str, str],
) -> Union[Generator, LLMResult]: ) -> Union[Generator, LLMResult]:
""" """
Run Cot agent application Run Cot agent application
""" """
@ -52,7 +52,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# init instruction # init instruction
inputs = inputs or {} inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) self._instruction = self._fill_in_inputs_from_external_data_tools(
instruction, inputs)
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@ -61,7 +62,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, self._prompt_messages_tools = self._init_prompt_tools() tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
prompt_messages = self._organize_prompt_messages() prompt_messages = self._organize_prompt_messages()
function_call_state = True function_call_state = True
llm_usage = { llm_usage = {
'usage': None 'usage': None
@ -120,9 +121,10 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# check llm result # check llm result
if not chunks: if not chunks:
raise ValueError("failed to invoke llm") raise ValueError("failed to invoke llm")
usage_dict = {} usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) react_chunks = CotAgentOutputParser.handle_react_stream_output(
chunks, usage_dict)
scratchpad = AgentScratchpadUnit( scratchpad = AgentScratchpadUnit(
agent_response='', agent_response='',
thought='', thought='',
@ -160,15 +162,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
) )
) )
scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you' scratchpad.thought = scratchpad.thought.strip(
) or 'I am thinking about how to help you'
self._agent_scratchpad.append(scratchpad) self._agent_scratchpad.append(scratchpad)
# get llm usage # get llm usage
if 'usage' in usage_dict: if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage']) increase_usage(llm_usage, usage_dict['usage'])
else: else:
usage_dict['usage'] = LLMUsage.empty_usage() usage_dict['usage'] = LLMUsage.empty_usage()
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '', tool_name=scratchpad.action.action_name if scratchpad.action else '',
@ -182,7 +185,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
messages_ids=[], messages_ids=[],
llm_usage=usage_dict['usage'] llm_usage=usage_dict['usage']
) )
if not scratchpad.is_final(): if not scratchpad.is_final():
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id agent_thought_id=agent_thought.id
@ -196,7 +199,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# action is final answer, return final answer directly # action is final answer, return final answer directly
try: try:
if isinstance(scratchpad.action.action_input, dict): if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input) final_answer = json.dumps(
scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str): elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input final_answer = scratchpad.action.action_input
else: else:
@ -207,7 +211,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
function_call_state = True function_call_state = True
# action is tool call, invoke tool # action is tool call, invoke tool
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action, action=scratchpad.action,
tool_instances=tool_instances, tool_instances=tool_instances,
message_file_ids=message_file_ids message_file_ids=message_file_ids
) )
@ -217,10 +221,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=scratchpad.action.action_name, tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, tool_input={
scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought, thought=scratchpad.thought,
observation={scratchpad.action.action_name: tool_invoke_response}, observation={
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response, answer=scratchpad.agent_response,
messages_ids=message_file_ids, messages_ids=message_file_ids,
llm_usage=usage_dict['usage'] llm_usage=usage_dict['usage']
@ -232,7 +239,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# update prompt tool message # update prompt tool message
for prompt_tool in self._prompt_messages_tools: for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) self.update_prompt_message_tool(
tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1 iteration_step += 1
@ -251,12 +259,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# save agent thought # save agent thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name='', tool_name='',
tool_input={}, tool_input={},
tool_invoke_meta={}, tool_invoke_meta={},
thought=final_answer, thought=final_answer,
observation={}, observation={},
answer=final_answer, answer=final_answer,
messages_ids=[] messages_ids=[]
) )
@ -269,11 +277,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message=AssistantPromptMessage( message=AssistantPromptMessage(
content=final_answer content=final_answer
), ),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(
),
system_fingerprint='' system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER) )), PublishFrom.APPLICATION_MANAGER)
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool], tool_instances: dict[str, Tool],
message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]: message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
""" """
@ -290,7 +299,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not tool_instance: if not tool_instance:
answer = f"there is not a tool named {tool_call_name}" answer = f"there is not a tool named {tool_call_name}"
return answer, ToolInvokeMeta.error_instance(answer) return answer, ToolInvokeMeta.error_instance(answer)
if isinstance(tool_call_args, str): if isinstance(tool_call_args, str):
try: try:
tool_call_args = json.loads(tool_call_args) tool_call_args = json.loads(tool_call_args)
@ -311,7 +320,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files # publish files
for message_file, save_as in message_files: for message_file, save_as in message_files:
if save_as: if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file.id, name=save_as)
# publish message file # publish message file
self.queue_manager.publish(QueueMessageFileEvent( self.queue_manager.publish(QueueMessageFileEvent(
@ -342,7 +352,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
continue continue
return instruction return instruction
def _init_react_state(self, query) -> None: def _init_react_state(self, query) -> None:
""" """
init agent scratchpad init agent scratchpad
@ -350,7 +360,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self._query = query self._query = query
self._agent_scratchpad = [] self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages() self._historic_prompt_messages = self._organize_historic_prompt_messages()
@abstractmethod @abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]: def _organize_prompt_messages(self) -> list[PromptMessage]:
""" """
@ -382,13 +392,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpads: list[AgentScratchpadUnit] = [] scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None current_scratchpad: AgentScratchpadUnit = None
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()
for message in self.history_prompt_messages: for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage): if isinstance(message, AssistantPromptMessage):
if not current_scratchpad: if not current_scratchpad:
@ -404,7 +407,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try: try:
current_scratchpad.action = AgentScratchpadUnit.Action( current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name, action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments) action_input=json.loads(
message.tool_calls[0].function.arguments)
) )
current_scratchpad.action_str = json.dumps( current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict() current_scratchpad.action.to_dict()
@ -424,10 +428,15 @@ class CotAgentRunner(BaseAgentRunner, ABC):
result.append(message) result.append(message)
if scratchpads: if scratchpads:
result.append(AssistantPromptMessage( result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads) content=self._format_assistant_message(scratchpads)
)) ))
return result historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory
).get_prompt()
return historic_prompts