fix: got unknown type of prompt message in multi-round ReAct agent chat (#5245)

This commit is contained in:
sino 2024-06-17 21:20:17 +08:00 committed by GitHub
parent 54756cd3b2
commit edffa5666d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,9 +32,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
_prompt_messages_tools: list[PromptMessage] = None
def run(self, message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
@ -52,7 +52,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(
instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@ -61,7 +62,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
prompt_messages = self._organize_prompt_messages()
function_call_state = True
llm_usage = {
'usage': None
@ -120,9 +121,10 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
react_chunks = CotAgentOutputParser.handle_react_stream_output(
chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
@ -160,15 +162,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
)
)
scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
scratchpad.thought = scratchpad.thought.strip(
) or 'I am thinking about how to help you'
self._agent_scratchpad.append(scratchpad)
# get llm usage
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
else:
usage_dict['usage'] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
@ -182,7 +185,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
messages_ids=[],
llm_usage=usage_dict['usage']
)
if not scratchpad.is_final():
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
@ -196,7 +199,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input)
final_answer = json.dumps(
scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
@ -207,7 +211,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
function_call_state = True
# action is tool call, invoke tool
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action,
action=scratchpad.action,
tool_instances=tool_instances,
message_file_ids=message_file_ids
)
@ -217,10 +221,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
observation={
scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict['usage']
@ -232,7 +239,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
self.update_prompt_message_tool(
tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
@ -251,12 +259,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name='',
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
observation={},
answer=final_answer,
messages_ids=[]
)
@ -269,11 +277,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(
),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
"""
@ -290,7 +299,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
return answer, ToolInvokeMeta.error_instance(answer)
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
@ -311,7 +320,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file.id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
@ -342,7 +352,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
continue
return instruction
def _init_react_state(self, query) -> None:
"""
init agent scratchpad
@ -350,7 +360,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self._query = query
self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages()
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
@ -382,13 +392,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
@ -404,7 +407,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments)
action_input=json.loads(
message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
@ -424,10 +428,15 @@ class CotAgentRunner(BaseAgentRunner, ABC):
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
return result
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory
).get_prompt()
return historic_prompts