diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 54aa0c9906..982477138b 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -61,8 +61,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): # convert tools into ModelRuntime Tool format tool_instances, self._prompt_messages_tools = self._init_prompt_tools() - prompt_messages = self._organize_prompt_messages() - function_call_state = True llm_usage = { 'usage': None diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index e8b05373ab..8debbe5c5d 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -5,6 +5,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, SystemPromptMessage, + TextPromptMessageContent, UserPromptMessage, ) from core.model_runtime.utils.encoders import jsonable_encoder @@ -25,6 +26,21 @@ class CotChatAgentRunner(CotAgentRunner): return SystemPromptMessage(content=system_prompt) + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + """ + Organize user query + """ + if self.files: + prompt_message_contents = [TextPromptMessageContent(data=query)] + for file_obj in self.files: + prompt_message_contents.append(file_obj.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + def _organize_prompt_messages(self) -> list[PromptMessage]: """ Organize @@ -51,27 +67,27 @@ class CotChatAgentRunner(CotAgentRunner): assistant_messages = [assistant_message] # query messages - query_messages = UserPromptMessage(content=self._query) + query_messages = self._organize_user_query(self._query, []) if assistant_messages: # organize historic prompt messages historic_messages = self._organize_historic_prompt_messages([ system_message, - query_messages, + *query_messages, *assistant_messages, UserPromptMessage(content='continue') - ]) + ]) messages = [ system_message, *historic_messages, - query_messages, + *query_messages, *assistant_messages, UserPromptMessage(content='continue') ] else: # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([system_message, query_messages]) - messages = [system_message, *historic_messages, query_messages] + historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages]) + messages = [system_message, *historic_messages, *query_messages] # join all messages - return messages \ No newline at end of file + return messages