fix: got unknown type of prompt message in multi-round ReAct agent chat (#5245)

This commit is contained in:
sino 2024-06-17 21:20:17 +08:00 committed by GitHub
parent 54756cd3b2
commit edffa5666d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,9 +32,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
_prompt_messages_tools: list[PromptMessage] = None
def run(self, message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
@ -52,7 +52,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(
instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@ -122,7 +123,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
raise ValueError("failed to invoke llm")
usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
react_chunks = CotAgentOutputParser.handle_react_stream_output(
chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
@ -160,7 +162,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
)
)
scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
scratchpad.thought = scratchpad.thought.strip(
) or 'I am thinking about how to help you'
self._agent_scratchpad.append(scratchpad)
# get llm usage
@ -196,7 +199,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input)
final_answer = json.dumps(
scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
@ -217,10 +221,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
observation={
scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict['usage']
@ -232,7 +239,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
self.update_prompt_message_tool(
tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
@ -269,7 +277,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(
),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
@ -311,7 +320,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file.id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
@ -382,13 +392,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
@ -404,7 +407,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments)
action_input=json.loads(
message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
@ -424,10 +428,15 @@ class CotAgentRunner(BaseAgentRunner, ABC):
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
return result
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory
).get_prompt()
return historic_prompts