From ae3ad59b166c963429f42f6b858e77b12cae0fa6 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Tue, 20 Feb 2024 19:03:43 +0800
Subject: [PATCH] =?UTF-8?q?Refactor=20agent=20history=20organization=20and?=
 =?UTF-8?q?=20initialization=20of=20agent=20scrat=E2=80=A6=20(#2495)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api/core/features/assistant_base_runner.py | 77 ++++++++++++++++++----
 api/core/features/assistant_cot_runner.py  | 35 ++++++++++
 2 files changed, 98 insertions(+), 14 deletions(-)

diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py
index c62028eaf0..c4a5767b04 100644
--- a/api/core/features/assistant_base_runner.py
+++ b/api/core/features/assistant_base_runner.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import uuid
 from datetime import datetime
 from mimetypes import guess_extension
 from typing import Optional, Union, cast
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = prompt_messages
+        self.history_prompt_messages = self.organize_agent_history(
+            prompt_messages=prompt_messages or []
+        )
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -504,17 +514,6 @@
             agent_thought.tool_labels_str = json.dumps(labels)
 
         db.session.commit()
-
-    def get_history_prompt_messages(self) -> list[PromptMessage]:
-        """
-        Get history prompt messages
-        """
-        if self.history_prompt_messages is None:
-            self.history_prompt_messages = db.session.query(PromptMessage).filter(
-                PromptMessage.message_id == self.message.id,
-            ).order_by(PromptMessage.position.asc()).all()
-
-        return self.history_prompt_messages
 
     def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
@@ -589,4 +588,54 @@
         """
         db_variables.updated_at = datetime.utcnow()
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
-        db.session.commit()
\ No newline at end of file
+        db.session.commit()
+
+    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Organize agent history
+        """
+        result = []
+        # check if there is a system message in the beginning of the conversation
+        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
+            result.append(prompt_messages[0])
+
+        messages: list[Message] = db.session.query(Message).filter(
+            Message.conversation_id == self.message.conversation_id,
+        ).order_by(Message.created_at.asc()).all()
+
+        for message in messages:
+            result.append(UserPromptMessage(content=message.query))
+            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
+            for agent_thought in agent_thoughts:
+                tools = agent_thought.tool
+                if tools:
+                    tools = tools.split(';')
+                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                    tool_call_response: list[ToolPromptMessage] = []
+                    tool_inputs = json.loads(agent_thought.tool_input)
+                    for tool in tools:
+                        # generate a uuid for tool call
+                        tool_call_id = str(uuid.uuid4())
+                        tool_calls.append(AssistantPromptMessage.ToolCall(
+                            id=tool_call_id,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool,
+                                arguments=json.dumps(tool_inputs.get(tool, {})),
+                            )
+                        ))
+                        tool_call_response.append(ToolPromptMessage(
+                            content=agent_thought.observation,
+                            name=tool,
+                            tool_call_id=tool_call_id,
+                        ))
+
+                    result.extend([
+                        AssistantPromptMessage(
+                            content=agent_thought.thought,
+                            tool_calls=tool_calls,
+                        ),
+                        *tool_call_response
+                    ])
+
+        return result
\ No newline at end of file
diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py
index b8d08bb5d3..c8477fb5d9 100644
--- a/api/core/features/assistant_cot_runner.py
+++ b/api/core/features/assistant_cot_runner.py
@@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -39,6 +40,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         self._repack_app_orchestration_config(app_orchestration_config)
 
         agent_scratchpad: list[AgentScratchpadUnit] = []
+        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
 
         # check model mode
         if self.app_orchestration_config.model_config.mode == "completion":
@@ -327,6 +329,39 @@
                 continue
 
         return instruction
+
+    def _init_agent_scratchpad(self,
+                               agent_scratchpad: list[AgentScratchpadUnit],
+                               messages: list[PromptMessage]
+                               ) -> list[AgentScratchpadUnit]:
+        """
+        init agent scratchpad
+        """
+        current_scratchpad: AgentScratchpadUnit = None
+        for message in messages:
+            if isinstance(message, AssistantPromptMessage):
+                current_scratchpad = AgentScratchpadUnit(
+                    agent_response=message.content,
+                    thought=message.content,
+                    action_str='',
+                    action=None,
+                    observation=None
+                )
+                if message.tool_calls:
+                    try:
+                        current_scratchpad.action = AgentScratchpadUnit.Action(
+                            action_name=message.tool_calls[0].function.name,
+                            action_input=json.loads(message.tool_calls[0].function.arguments)
+                        )
+                    except:
+                        pass
+
+                agent_scratchpad.append(current_scratchpad)
+            elif isinstance(message, ToolPromptMessage):
+                if current_scratchpad:
+                    current_scratchpad.observation = message.content
+
+        return agent_scratchpad
 
     def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
         """