diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 8955ad2d1a..e9d208330b 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -5,6 +5,7 @@
 from datetime import datetime
 from typing import Optional, Union, cast
 from core.agent.entities import AgentEntity, AgentToolEntity
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.base_app_runner import AppRunner
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file.message_file_parser import MessageFileParser
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -22,6 +24,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
@@ -37,7 +40,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
 from extensions.ext_database import db
-from models.model import Message, MessageAgentThought
+from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ logger = logging.getLogger(__name__)
 
 class BaseAgentRunner(AppRunner):
     def __init__(self, tenant_id: str, application_generate_entity: AgentChatAppGenerateEntity,
+                 conversation: Conversation,
                  app_config: AgentChatAppConfig,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: AgentEntity,
@@ -72,6 +76,7 @@ class BaseAgentRunner(AppRunner):
         """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
         self.app_config = app_config
         self.model_config = model_config
         self.config = config
@@ -118,6 +123,12 @@ class BaseAgentRunner(AppRunner):
         else:
             self.stream_tool_call = False
 
+        # check if model supports vision
+        if model_schema and ModelFeature.VISION in (model_schema.features or []):
+            self.files = application_generate_entity.files
+        else:
+            self.files = []
+
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
         -> AgentChatAppGenerateEntity:
         """
@@ -412,15 +423,19 @@ class BaseAgentRunner(AppRunner):
         """
         result = []
         # check if there is a system message in the beginning of the conversation
-        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
-            result.append(prompt_messages[0])
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                result.append(prompt_message)
 
         messages: list[Message] = db.session.query(Message).filter(
             Message.conversation_id == self.message.conversation_id,
         ).order_by(Message.created_at.asc()).all()
 
         for message in messages:
-            result.append(UserPromptMessage(content=message.query))
+            if message.id == self.message.id:
+                continue
+
+            result.append(self.organize_agent_user_prompt(message))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
             if agent_thoughts:
                 for agent_thought in agent_thoughts:
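Note on the constructor hunk above: `self.files` is only populated when the model schema advertises vision, so non-vision models never receive image blocks. A minimal runnable sketch of that gate, using stand-in types rather than Dify's real `ModelFeature` and entity classes:

```python
from enum import Enum


class ModelFeature(Enum):
    # stand-in for Dify's ModelFeature enum; only the members used here
    VISION = 'vision'
    TOOL_CALL = 'tool-call'


def select_files(features: list, uploaded_files: list) -> list:
    # mirror of the check in __init__: forward uploaded files only when the
    # model supports vision, otherwise drop them so prompts stay text-only
    if features and ModelFeature.VISION in features:
        return uploaded_files
    return []


assert select_files([ModelFeature.VISION], ['cat.png']) == ['cat.png']
assert select_files([ModelFeature.TOOL_CALL], ['cat.png']) == []
```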
@@ -471,3 +486,32 @@ class BaseAgentRunner(AppRunner):
         db.session.close()
 
         return result
+
+    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+        message_file_parser = MessageFileParser(
+            tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
+        )
+
+        files = message.message_files
+        if files:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+            if file_extra_config:
+                file_objs = message_file_parser.transform_message_files(
+                    files,
+                    file_extra_config
+                )
+            else:
+                file_objs = []
+
+            if not file_objs:
+                return UserPromptMessage(content=message.query)
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                return UserPromptMessage(content=prompt_message_contents)
+        else:
+            return UserPromptMessage(content=message.query)
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 4ad5df5cfc..3b39bb1951 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -19,15 +19,14 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message
+from models.model import Message
 
 
 class CotAgentRunner(BaseAgentRunner):
     _is_first_iteration = True
     _ignore_observation_providers = ['wenxin']
 
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
             inputs: dict[str, str],
             ) -> Union[Generator, LLMResult]:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 732a6ee750..ea5d31293d 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -1,6 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Any, Union
 
 from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message, MessageAgentThought
+from models.model import Message
 
 logger = logging.getLogger(__name__)
 
 
 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
             ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_template = app_config.prompt_template.simple_prompt_template or ''
         prompt_messages = self.history_prompt_messages
 
-        prompt_messages = self.organize_prompt_messages(
-            prompt_template=prompt_template,
-            query=query,
-            prompt_messages=prompt_messages
-        )
+        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
+        prompt_messages = self._organize_user_query(query, prompt_messages)
 
         # convert tools into ModelRuntime Tool format
         prompt_messages_tools: list[PromptMessageTool] = []
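Both `organize_agent_user_prompt` above and `_organize_user_query` later in this patch build the user message the same way: a plain string when no files are attached, otherwise a content list with the text block first. A sketch with simplified stand-ins for the content classes (the real ones live in `core.model_runtime.entities.message_entities`):

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class TextContent:
    data: str


@dataclass
class ImageContent:
    data: str  # e.g. a URL or base64 payload


def build_user_content(query: str, file_contents: list) -> Union[str, list]:
    # no files: keep the cheap string form; with files: switch to the typed
    # multimodal list, text first, then one block per attached file
    if not file_contents:
        return query
    return [TextContent(data=query), *file_contents]


print(build_user_content('describe this image',
                         [ImageContent(data='https://example.com/cat.png')]))
```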
@@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         # continue to run until there is not any tool call
         function_call_state = True
-        agent_thoughts: list[MessageAgentThought] = []
         llm_usage = {
             'usage': None
         }
@@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 }
                 tool_responses.append(tool_response)
 
-            prompt_messages = self.organize_prompt_messages(
-                prompt_template=prompt_template,
-                query=None,
+            prompt_messages = self._organize_assistant_message(
                 tool_call_id=tool_call_id,
                 tool_call_name=tool_call_name,
                 tool_response=tool_response['tool_response'],
@@ -324,6 +320,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             iteration_step += 1
 
+        prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
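The `_clear_user_prompt_image_messages` call added above (defined in the next hunk) flattens multimodal user content back to plain text once the agent loop finishes, per the comment below about gpt only supporting function calling plus vision on the first iteration. A toy demonstration of the flattening rule, with dicts standing in for the content objects:

```python
def flatten_user_content(blocks: list) -> str:
    # text blocks keep their data; media blocks collapse to short
    # placeholders, matching the '[image]' / '[file]' substitution below
    rendered = []
    for block in blocks:
        if block['type'] == 'text':
            rendered.append(block['data'])
        elif block['type'] == 'image':
            rendered.append('[image]')
        else:
            rendered.append('[file]')
    return '\n'.join(rendered)


print(flatten_user_content([
    {'type': 'text', 'data': 'what is in this picture?'},
    {'type': 'image', 'data': 'https://example.com/cat.png'},
]))
# -> what is in this picture?
#    [image]
```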
+ """ + prompt_messages = deepcopy(prompt_messages) + + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = '\n'.join([ + content.data if content.type == PromptMessageContentType.TEXT else + '[image]' if content.type == PromptMessageContentType.IMAGE else + '[file]' + for content in prompt_message.content + ]) return prompt_messages \ No newline at end of file diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 0dc8a1e218..f42b146e51 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -210,6 +210,7 @@ class AgentChatAppRunner(AppRunner): assistant_cot_runner = CotAgentRunner( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, + conversation=conversation, app_config=app_config, model_config=application_generate_entity.model_config, config=agent_entity, @@ -223,7 +224,6 @@ class AgentChatAppRunner(AppRunner): model_instance=model_instance ) invoke_result = assistant_cot_runner.run( - conversation=conversation, message=message, query=query, inputs=inputs, @@ -232,6 +232,7 @@ class AgentChatAppRunner(AppRunner): assistant_fc_runner = FunctionCallAgentRunner( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, + conversation=conversation, app_config=app_config, model_config=application_generate_entity.model_config, config=agent_entity, @@ -245,7 +246,6 @@ class AgentChatAppRunner(AppRunner): model_instance=model_instance ) invoke_result = assistant_fc_runner.run( - conversation=conversation, message=message, query=query, ) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 46f17fe19b..b7db39376c 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if user: extra_model_kwargs['user'] = user + # clear illegal prompt messages + prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) + # chat model response = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], @@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return tool_call + def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Clear illegal prompt messages for OpenAI API + + :param model: model name + :param prompt_messages: prompt messages + :return: cleaned prompt messages + """ + checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09'] + + if model in checklist: + # count how many user messages are there + user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)]) + if user_message_count > 1: + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = '\n'.join([ + item.data if item.type == PromptMessageContentType.TEXT else + '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else '' + for item in prompt_message.content + ]) + + return prompt_messages + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ Convert PromptMessage to dict for OpenAI API diff 
diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts
index 8aeaf1463a..4cdb6e8e38 100644
--- a/web/app/components/base/chat/chat/hooks.ts
+++ b/web/app/components/base/chat/chat/hooks.ts
@@ -229,7 +229,7 @@ export const useChat = (
 
     // answer
     const responseItem: ChatItem = {
-      id: `${Date.now()}`,
+      id: placeholderAnswerId,
       content: '',
       agent_thoughts: [],
       message_files: [],
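The hooks.ts fix above makes the streamed answer reuse `placeholderAnswerId` instead of minting a fresh `Date.now()` id. Sketched here in Python (the original is TypeScript) to show why the stable id matters: streaming updates look the pending item up by id, and a mismatched id would orphan the placeholder:

```python
chat_list = [
    {'id': 'question-1', 'content': 'hi'},
    {'id': 'answer-placeholder-1', 'content': ''},  # shown while streaming
]


def apply_chunk(items: list, item_id: str, delta: str) -> None:
    # locate the pending answer by id and append the streamed delta;
    # with a brand-new id this loop would never match the placeholder
    for item in items:
        if item['id'] == item_id:
            item['content'] += delta
            return


apply_chunk(chat_list, 'answer-placeholder-1', 'hello!')
assert chat_list[1]['content'] == 'hello!'
```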