Refactor agent history organization and initialization of agent scrat… (#2495)

commit ae3ad59b16
parent e6cd7b0467
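As a rough illustration of the pairing logic the new organize_agent_history() introduces (this sketch is not part of the commit): each stored agent thought becomes one assistant message carrying synthetic tool_calls plus one tool response per tool, linked by a freshly generated uuid. The stand-in dataclasses below (ToolCall, AssistantMsg, ToolMsg) and the helper rebuild_thought are hypothetical, not dify's real entities.

    # Illustrative sketch only: hypothetical stand-ins for dify's prompt message classes.
    import json
    import uuid
    from dataclasses import dataclass, field

    @dataclass
    class ToolCall:
        id: str
        name: str
        arguments: str  # JSON-encoded arguments for the tool

    @dataclass
    class AssistantMsg:
        content: str
        tool_calls: list[ToolCall] = field(default_factory=list)

    @dataclass
    class ToolMsg:
        content: str
        name: str
        tool_call_id: str

    def rebuild_thought(thought: str, tools: str, tool_input: str, observation: str) -> list:
        """Turn one stored agent-thought row into an assistant message plus tool responses."""
        calls, responses = [], []
        inputs = json.loads(tool_input) if tool_input else {}
        for tool in filter(None, tools.split(';')):
            call_id = str(uuid.uuid4())  # fresh id links the call to its response
            calls.append(ToolCall(id=call_id, name=tool,
                                  arguments=json.dumps(inputs.get(tool, {}))))
            responses.append(ToolMsg(content=observation, name=tool, tool_call_id=call_id))
        return [AssistantMsg(content=thought, tool_calls=calls), *responses]

    # Example: one thought that invoked a single "web_search" tool
    print(rebuild_thought("search first", "web_search",
                          '{"web_search": {"query": "dify"}}', "result text"))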
@@ -1,5 +1,6 @@
 import json
 import logging
+import uuid
 from datetime import datetime
 from mimetypes import guess_extension
 from typing import Optional, Union, cast
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = prompt_messages
+        self.history_prompt_messages = self.organize_agent_history(
+            prompt_messages=prompt_messages or []
+        )
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -505,17 +515,6 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         db.session.commit()
 
-    def get_history_prompt_messages(self) -> list[PromptMessage]:
-        """
-        Get history prompt messages
-        """
-        if self.history_prompt_messages is None:
-            self.history_prompt_messages = db.session.query(PromptMessage).filter(
-                PromptMessage.message_id == self.message.id,
-            ).order_by(PromptMessage.position.asc()).all()
-
-        return self.history_prompt_messages
-
     def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
         Transform tool message into agent thought
@@ -590,3 +589,53 @@ class BaseAssistantApplicationRunner(AppRunner):
         db_variables.updated_at = datetime.utcnow()
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
+
+    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Organize agent history
+        """
+        result = []
+        # check if there is a system message in the beginning of the conversation
+        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
+            result.append(prompt_messages[0])
+
+        messages: list[Message] = db.session.query(Message).filter(
+            Message.conversation_id == self.message.conversation_id,
+        ).order_by(Message.created_at.asc()).all()
+
+        for message in messages:
+            result.append(UserPromptMessage(content=message.query))
+            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
+            for agent_thought in agent_thoughts:
+                tools = agent_thought.tool
+                if tools:
+                    tools = tools.split(';')
+                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                    tool_call_response: list[ToolPromptMessage] = []
+                    tool_inputs = json.loads(agent_thought.tool_input)
+                    for tool in tools:
+                        # generate a uuid for tool call
+                        tool_call_id = str(uuid.uuid4())
+                        tool_calls.append(AssistantPromptMessage.ToolCall(
+                            id=tool_call_id,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool,
+                                arguments=json.dumps(tool_inputs.get(tool, {})),
+                            )
+                        ))
+                        tool_call_response.append(ToolPromptMessage(
+                            content=agent_thought.observation,
+                            name=tool,
+                            tool_call_id=tool_call_id,
+                        ))
+
+                    result.extend([
+                        AssistantPromptMessage(
+                            content=agent_thought.thought,
+                            tool_calls=tool_calls,
+                        ),
+                        *tool_call_response
+                    ])
+
+        return result
@@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -39,6 +40,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         self._repack_app_orchestration_config(app_orchestration_config)
 
         agent_scratchpad: list[AgentScratchpadUnit] = []
+        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
 
         # check model mode
         if self.app_orchestration_config.model_config.mode == "completion":
@@ -328,6 +330,39 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
 
         return instruction
 
+    def _init_agent_scratchpad(self,
+                               agent_scratchpad: list[AgentScratchpadUnit],
+                               messages: list[PromptMessage]
+                               ) -> list[AgentScratchpadUnit]:
+        """
+        init agent scratchpad
+        """
+        current_scratchpad: AgentScratchpadUnit = None
+        for message in messages:
+            if isinstance(message, AssistantPromptMessage):
+                current_scratchpad = AgentScratchpadUnit(
+                    agent_response=message.content,
+                    thought=message.content,
+                    action_str='',
+                    action=None,
+                    observation=None
+                )
+                if message.tool_calls:
+                    try:
+                        current_scratchpad.action = AgentScratchpadUnit.Action(
+                            action_name=message.tool_calls[0].function.name,
+                            action_input=json.loads(message.tool_calls[0].function.arguments)
+                        )
+                    except:
+                        pass
+
+                agent_scratchpad.append(current_scratchpad)
+            elif isinstance(message, ToolPromptMessage):
+                if current_scratchpad:
+                    current_scratchpad.observation = message.content
+
+        return agent_scratchpad
+
     def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
         """
         extract response from llm response
|
Loading…
x
Reference in New Issue
Block a user