fix: convert tool messages into user messages in react mode and fill … (#2584)

This commit is contained in:
Yeuoly 2024-02-27 19:15:07 +08:00 committed by GitHub
parent 29ab244de6
commit 3a34370422
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 32 deletions

View File

@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner):
for message in messages: for message in messages:
result.append(UserPromptMessage(content=message.query)) result.append(UserPromptMessage(content=message.query))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
for agent_thought in agent_thoughts: if agent_thoughts:
tools = agent_thought.tool for agent_thought in agent_thoughts:
if tools: tools = agent_thought.tool
tools = tools.split(';') if tools:
tool_calls: list[AssistantPromptMessage.ToolCall] = [] tools = tools.split(';')
tool_call_response: list[ToolPromptMessage] = [] tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_inputs = json.loads(agent_thought.tool_input) tool_call_response: list[ToolPromptMessage] = []
for tool in tools: tool_inputs = json.loads(agent_thought.tool_input)
# generate a uuid for tool call for tool in tools:
tool_call_id = str(uuid.uuid4()) # generate a uuid for tool call
tool_calls.append(AssistantPromptMessage.ToolCall( tool_call_id = str(uuid.uuid4())
id=tool_call_id, tool_calls.append(AssistantPromptMessage.ToolCall(
type='function', id=tool_call_id,
function=AssistantPromptMessage.ToolCall.ToolCallFunction( type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
name=tool, name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})), tool_call_id=tool_call_id,
) ))
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
name=tool,
tool_call_id=tool_call_id,
))
result.extend([ result.extend([
AssistantPromptMessage( AssistantPromptMessage(
content=agent_thought.thought, content=agent_thought.thought,
tool_calls=tool_calls, tool_calls=tool_calls,
), ),
*tool_call_response *tool_call_response
]) ])
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
return result return result

View File

@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
thought='', thought='',
action_str='', action_str='',
observation='', observation='',
action=None action=None,
) )
# publish agent thought if it's first iteration # publish agent thought if it's first iteration
@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
thought=message.content, thought=message.content,
action_str='', action_str='',
action=None, action=None,
observation=None observation=None,
) )
if message.tool_calls: if message.tool_calls:
try: try:
@ -608,6 +608,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
overridden = True overridden = True
break break
# convert tool prompt messages to user prompt messages
for idx, prompt_message in enumerate(prompt_messages):
if isinstance(prompt_message, ToolPromptMessage):
prompt_messages[idx] = UserPromptMessage(
content=prompt_message.content
)
if not overridden: if not overridden:
prompt_messages.insert(0, SystemPromptMessage( prompt_messages.insert(0, SystemPromptMessage(
content=system_message, content=system_message,