diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index c4a5767b04..2a4ae7e135 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner): for message in messages: result.append(UserPromptMessage(content=message.query)) agent_thoughts: list[MessageAgentThought] = message.agent_thoughts - for agent_thought in agent_thoughts: - tools = agent_thought.tool - if tools: - tools = tools.split(';') - tool_calls: list[AssistantPromptMessage.ToolCall] = [] - tool_call_response: list[ToolPromptMessage] = [] - tool_inputs = json.loads(agent_thought.tool_input) - for tool in tools: - # generate a uuid for tool call - tool_call_id = str(uuid.uuid4()) - tool_calls.append(AssistantPromptMessage.ToolCall( - id=tool_call_id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( + if agent_thoughts: + for agent_thought in agent_thoughts: + tools = agent_thought.tool + if tools: + tools = tools.split(';') + tool_calls: list[AssistantPromptMessage.ToolCall] = [] + tool_call_response: list[ToolPromptMessage] = [] + tool_inputs = json.loads(agent_thought.tool_input) + for tool in tools: + # generate a uuid for tool call + tool_call_id = str(uuid.uuid4()) + tool_calls.append(AssistantPromptMessage.ToolCall( + id=tool_call_id, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool, + arguments=json.dumps(tool_inputs.get(tool, {})), + ) + )) + tool_call_response.append(ToolPromptMessage( + content=agent_thought.observation, name=tool, - arguments=json.dumps(tool_inputs.get(tool, {})), - ) - )) - tool_call_response.append(ToolPromptMessage( - content=agent_thought.observation, - name=tool, - tool_call_id=tool_call_id, - )) + tool_call_id=tool_call_id, + )) - result.extend([ - AssistantPromptMessage( - content=agent_thought.thought, - tool_calls=tool_calls, - ), - *tool_call_response - ]) + result.extend([ + AssistantPromptMessage( + content=agent_thought.thought, + tool_calls=tool_calls, + ), + *tool_call_response + ]) + if not tools: + result.append(AssistantPromptMessage(content=agent_thought.thought)) + else: + if message.answer: + result.append(AssistantPromptMessage(content=message.answer)) return result \ No newline at end of file diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index aa4a6797cd..809834c8cb 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): thought='', action_str='', observation='', - action=None + action=None, ) # publish agent thought if it's first iteration @@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): thought=message.content, action_str='', action=None, - observation=None + observation=None, ) if message.tool_calls: try: @@ -484,7 +484,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): elif isinstance(message, ToolPromptMessage): if current_scratchpad: current_scratchpad.observation = message.content - + return agent_scratchpad def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], @@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): prompt_message.content = system_message overridden = True break + + # convert tool prompt messages to user prompt messages + for idx, prompt_message in enumerate(prompt_messages): + if isinstance(prompt_message, ToolPromptMessage): + prompt_messages[idx] = UserPromptMessage( + content=prompt_message.content + ) if not overridden: prompt_messages.insert(0, SystemPromptMessage(