Feat/Agent-Image-Processing (#3293)

Co-authored-by: Joel <iamjoel007@gmail.com>
Commit: 14bb0b02ac (parent: 240c793e7a), committed by GitHub
Author: Yeuoly
Date: 2024-04-10 14:48:40 +08:00
6 changed files with 148 additions and 40 deletions

api/core/agent/base_agent_runner.py

@@ -5,6 +5,7 @@ from datetime import datetime
 from typing import Optional, Union, cast

 from core.agent.entities import AgentEntity, AgentToolEntity
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.base_app_runner import AppRunner
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file.message_file_parser import MessageFileParser
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -22,6 +24,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
@@ -37,7 +40,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
 from extensions.ext_database import db
-from models.model import Message, MessageAgentThought
+from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables

 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ logger = logging.getLogger(__name__)
 class BaseAgentRunner(AppRunner):
     def __init__(self, tenant_id: str,
                  application_generate_entity: AgentChatAppGenerateEntity,
+                 conversation: Conversation,
                  app_config: AgentChatAppConfig,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: AgentEntity,
@@ -72,6 +76,7 @@ class BaseAgentRunner(AppRunner):
         """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
         self.app_config = app_config
         self.model_config = model_config
         self.config = config
@@ -118,6 +123,12 @@ class BaseAgentRunner(AppRunner):
         else:
             self.stream_tool_call = False

+        # check if model supports vision
+        if model_schema and ModelFeature.VISION in (model_schema.features or []):
+            self.files = application_generate_entity.files
+        else:
+            self.files = []
+
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
         """
@@ -412,15 +423,19 @@ class BaseAgentRunner(AppRunner):
         """
         result = []
         # check if there is a system message in the beginning of the conversation
-        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
-            result.append(prompt_messages[0])
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                result.append(prompt_message)

         messages: list[Message] = db.session.query(Message).filter(
             Message.conversation_id == self.message.conversation_id,
         ).order_by(Message.created_at.asc()).all()

         for message in messages:
-            result.append(UserPromptMessage(content=message.query))
+            if message.id == self.message.id:
+                continue
+
+            result.append(self.organize_agent_user_prompt(message))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
             if agent_thoughts:
                 for agent_thought in agent_thoughts:
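The loop rewrite above changes history assembly in two ways: every SystemPromptMessage is kept (not just one at index 0), and the message currently being answered is skipped so it is not replayed as history. A sketch of that control flow under simplified stand-in types (only `id`/`query` fields are assumed; the real code appends a UserPromptMessage built by `organize_agent_user_prompt`):

```python
from dataclasses import dataclass


@dataclass
class SystemPromptMessage:
    content: str


@dataclass
class StoredMessage:
    id: str
    query: str


def rebuild_history(prompt_messages, stored_messages, current_message_id):
    # keep every system message from the freshly organized prompt
    result = [m for m in prompt_messages if isinstance(m, SystemPromptMessage)]
    for msg in stored_messages:
        if msg.id == current_message_id:
            continue  # the in-flight message is not part of the history yet
        result.append(msg.query)  # real code builds a UserPromptMessage here
    return result


history = rebuild_history(
    [SystemPromptMessage(content='You are a helpful agent.')],
    [StoredMessage(id='m1', query='hi'), StoredMessage(id='m2', query='draw a cat')],
    current_message_id='m2',
)
print(history)  # [SystemPromptMessage(content='You are a helpful agent.'), 'hi']
```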
@@ -471,3 +486,32 @@ class BaseAgentRunner(AppRunner):
         db.session.close()

         return result
+
+    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+        message_file_parser = MessageFileParser(
+            tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
+        )
+
+        files = message.message_files
+        if files:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+            if file_extra_config:
+                file_objs = message_file_parser.transform_message_files(
+                    files,
+                    file_extra_config
+                )
+            else:
+                file_objs = []
+
+            if not file_objs:
+                return UserPromptMessage(content=message.query)
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                return UserPromptMessage(content=prompt_message_contents)
+        else:
+            return UserPromptMessage(content=message.query)
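`organize_agent_user_prompt` falls back to a plain-string UserPromptMessage unless the stored message actually has files and the app's upload config can be resolved; with files, the content becomes a list of parts with the text first. A hedged sketch of the resulting message shape (the content-part classes below are simplified stand-ins for Dify's prompt message entities):

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class TextPromptMessageContent:
    data: str


@dataclass
class ImagePromptMessageContent:
    data: str  # URL or base64 payload in the real entity


@dataclass
class UserPromptMessage:
    content: Union[str, list]


def build_user_prompt(query, file_contents):
    # No files (or no resolvable upload config): plain string content.
    if not file_contents:
        return UserPromptMessage(content=query)
    # Files present: multi-part content, text first, then one part per file.
    return UserPromptMessage(content=[TextPromptMessageContent(data=query), *file_contents])


print(build_user_prompt('what is this?', [ImagePromptMessageContent(data='https://example.com/a.png')]))
```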

api/core/agent/cot_agent_runner.py

@@ -19,15 +19,14 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message
+from models.model import Message


 class CotAgentRunner(BaseAgentRunner):
     _is_first_iteration = True
     _ignore_observation_providers = ['wenxin']

-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
             inputs: dict[str, str],
             ) -> Union[Generator, LLMResult]:

api/core/agent/fc_agent_runner.py

@@ -1,6 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Any, Union

 from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message, MessageAgentThought
+from models.model import Message

 logger = logging.getLogger(__name__)


 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
             ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_template = app_config.prompt_template.simple_prompt_template or ''
         prompt_messages = self.history_prompt_messages

-        prompt_messages = self.organize_prompt_messages(
-            prompt_template=prompt_template,
-            query=query,
-            prompt_messages=prompt_messages
-        )
+        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
+        prompt_messages = self._organize_user_query(query, prompt_messages)

         # convert tools into ModelRuntime Tool format
         prompt_messages_tools: list[PromptMessageTool] = []
@@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         # continue to run until there is not any tool call
         function_call_state = True
-        agent_thoughts: list[MessageAgentThought] = []
         llm_usage = {
             'usage': None
         }
@@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 }
                 tool_responses.append(tool_response)

-            prompt_messages = self.organize_prompt_messages(
-                prompt_template=prompt_template,
-                query=None,
+            prompt_messages = self._organize_assistant_message(
                 tool_call_id=tool_call_id,
                 tool_call_name=tool_call_name,
                 tool_response=tool_response['tool_response'],
@@ -324,6 +320,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             iteration_step += 1

+        prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -386,29 +384,68 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         return tool_calls

-    def organize_prompt_messages(self, prompt_template: str,
-                                 query: str = None,
-                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                 prompt_messages: list[PromptMessage] = None
-                                 ) -> list[PromptMessage]:
+    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
-        Organize prompt messages
+        Initialize system message
         """
-        if not prompt_messages:
-            prompt_messages = [
-                SystemPromptMessage(content=prompt_template),
-                UserPromptMessage(content=query),
-            ]
-        else:
-            if tool_response:
-                prompt_messages = prompt_messages.copy()
-                prompt_messages.append(
-                    ToolPromptMessage(
-                        content=tool_response,
-                        tool_call_id=tool_call_id,
-                        name=tool_call_name,
-                    )
-                )
+        if not prompt_messages and prompt_template:
+            return [
+                SystemPromptMessage(content=prompt_template),
+            ]
+
+        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
+            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
+
+        return prompt_messages
+
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize user query
+        """
+        if self.files:
+            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            for file_obj in self.files:
+                prompt_message_contents.append(file_obj.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=query))
+
+        return prompt_messages
+
+    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
+                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize assistant message
+        """
+        prompt_messages = deepcopy(prompt_messages)
+        if tool_response is not None:
+            prompt_messages.append(
+                ToolPromptMessage(
+                    content=tool_response,
+                    tool_call_id=tool_call_id,
+                    name=tool_call_name,
+                )
+            )
+
+        return prompt_messages
+
+    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        As for now, gpt supports both fc and vision at the first iteration.
+        We need to remove the image messages from the prompt messages at the first iteration.
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, UserPromptMessage):
+                if isinstance(prompt_message.content, list):
+                    prompt_message.content = '\n'.join([
+                        content.data if content.type == PromptMessageContentType.TEXT else
+                        '[image]' if content.type == PromptMessageContentType.IMAGE else
+                        '[file]'
+                        for content in prompt_message.content
+                    ])

         return prompt_messages
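`_clear_user_prompt_image_messages` collapses any multi-part user content back to plain text, replacing images with an `[image]` placeholder, and it deep-copies first so the caller's history is never mutated. A runnable sketch of that flattening step, assuming simplified stand-in types for Dify's prompt message entities:

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Union


@dataclass
class Content:
    type: str  # 'text' | 'image'
    data: str


@dataclass
class UserMsg:
    content: Union[str, list]


def clear_user_image_messages(messages: list) -> list:
    messages = deepcopy(messages)  # never mutate the caller's history
    for m in messages:
        if isinstance(m.content, list):
            m.content = '\n'.join(
                c.data if c.type == 'text' else '[image]'
                for c in m.content
            )
    return messages


msgs = [UserMsg(content=[Content('text', 'what is in this picture?'),
                         Content('image', 'https://example.com/a.png')])]
print(clear_user_image_messages(msgs)[0].content)
# what is in this picture?
# [image]
```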

api/core/app/apps/agent_chat/app_runner.py

@@ -210,6 +210,7 @@ class AgentChatAppRunner(AppRunner):
             assistant_cot_runner = CotAgentRunner(
                 tenant_id=app_config.tenant_id,
                 application_generate_entity=application_generate_entity,
+                conversation=conversation,
                 app_config=app_config,
                 model_config=application_generate_entity.model_config,
                 config=agent_entity,
@@ -223,7 +224,6 @@ class AgentChatAppRunner(AppRunner):
                 model_instance=model_instance
             )
             invoke_result = assistant_cot_runner.run(
-                conversation=conversation,
                 message=message,
                 query=query,
                 inputs=inputs,
@@ -232,6 +232,7 @@ class AgentChatAppRunner(AppRunner):
             assistant_fc_runner = FunctionCallAgentRunner(
                 tenant_id=app_config.tenant_id,
                 application_generate_entity=application_generate_entity,
+                conversation=conversation,
                 app_config=app_config,
                 model_config=application_generate_entity.model_config,
                 config=agent_entity,
@@ -245,7 +246,6 @@ class AgentChatAppRunner(AppRunner):
                 model_instance=model_instance
             )
             invoke_result = assistant_fc_runner.run(
-                conversation=conversation,
                 message=message,
                 query=query,
             )
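Across both runners the diff moves `conversation` from a `run()` argument to a constructor argument, so per-conversation state is injected once and is readable by every helper on the instance (presumably for the new history-building code in BaseAgentRunner). A toy illustration of the shape change; all names here are hypothetical:

```python
class AgentRunner:
    def __init__(self, conversation_id: str):
        # injected once at construction; every helper can read it via self
        self.conversation_id = conversation_id

    def run(self, message_id: str, query: str) -> str:
        # run() no longer needs the conversation passed in on each call
        return f'{self.conversation_id}/{message_id}: {query}'


runner = AgentRunner(conversation_id='conv-1')
print(runner.run(message_id='msg-1', query='hello'))
```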

api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user

+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         return tool_call

+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
+
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = '\n'.join([
+                                item.data if item.type == PromptMessageContentType.TEXT else
+                                '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
+                                for item in prompt_message.content
+                            ])
+
+        return prompt_messages
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API
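`_clear_illegal_prompt_messages` only touches the two gpt-4-turbo snapshots in the checklist, and only when more than one user message is present; multi-part user contents are then joined into plain text with `[IMAGE]` placeholders (presumably working around an API restriction on repeated multimodal turns at the time). A condensed, runnable sketch with simplified stand-in message types:

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Part:
    type: str  # 'text' or 'image'
    data: str


@dataclass
class UserMsg:
    content: Union[str, list]


def clear_illegal(model: str, user_messages: list) -> list:
    # real code scans all prompt messages and counts only the user ones;
    # this sketch takes the user messages directly for brevity
    checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
    if model in checklist and len(user_messages) > 1:
        for m in user_messages:
            if isinstance(m.content, list):
                m.content = '\n'.join(
                    p.data if p.type == 'text' else '[IMAGE]' if p.type == 'image' else ''
                    for p in m.content
                )
    return user_messages


msgs = [UserMsg('hi'), UserMsg([Part('text', 'and this?'), Part('image', 'img-b64')])]
print(clear_illegal('gpt-4-turbo', msgs)[1].content)
# and this?
# [IMAGE]
```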

web/app/components/base/chat/chat/hooks.ts

@@ -229,7 +229,7 @@ export const useChat = (
     // answer
     const responseItem: ChatItem = {
-      id: `${Date.now()}`,
+      id: placeholderAnswerId,
       content: '',
       agent_thoughts: [],
       message_files: [],