mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 19:39:02 +08:00
Feat/Agent-Image-Processing (#3293)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
parent
240c793e7a
commit
14bb0b02ac
@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -22,6 +24,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@ -37,7 +40,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageAgentThought
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -45,6 +48,7 @@ logger = logging.getLogger(__name__)
|
||||
class BaseAgentRunner(AppRunner):
|
||||
def __init__(self, tenant_id: str,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
conversation: Conversation,
|
||||
app_config: AgentChatAppConfig,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: AgentEntity,
|
||||
@ -72,6 +76,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
self.app_config = app_config
|
||||
self.model_config = model_config
|
||||
self.config = config
|
||||
@ -118,6 +123,12 @@ class BaseAgentRunner(AppRunner):
|
||||
else:
|
||||
self.stream_tool_call = False
|
||||
|
||||
# check if model supports vision
|
||||
if model_schema and ModelFeature.VISION in (model_schema.features or []):
|
||||
self.files = application_generate_entity.files
|
||||
else:
|
||||
self.files = []
|
||||
|
||||
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
|
||||
-> AgentChatAppGenerateEntity:
|
||||
"""
|
||||
@ -412,15 +423,19 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
result = []
|
||||
# check if there is a system message in the beginning of the conversation
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
result.append(prompt_messages[0])
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
result.append(prompt_message)
|
||||
|
||||
messages: list[Message] = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
).order_by(Message.created_at.asc()).all()
|
||||
|
||||
for message in messages:
|
||||
result.append(UserPromptMessage(content=message.query))
|
||||
if message.id == self.message.id:
|
||||
continue
|
||||
|
||||
result.append(self.organize_agent_user_prompt(message))
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
@ -471,3 +486,32 @@ class BaseAgentRunner(AppRunner):
|
||||
db.session.close()
|
||||
|
||||
return result
|
||||
|
||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||
message_file_parser = MessageFileParser(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
)
|
||||
|
||||
files = message.message_files
|
||||
if files:
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
files,
|
||||
file_extra_config
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
return UserPromptMessage(content=message.query)
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
||||
for file_obj in file_objs:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
return UserPromptMessage(content=prompt_message_contents)
|
||||
else:
|
||||
return UserPromptMessage(content=message.query)
|
||||
|
@ -19,15 +19,14 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Conversation, Message
|
||||
from models.model import Message
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ['wenxin']
|
||||
|
||||
def run(self, conversation: Conversation,
|
||||
message: Message,
|
||||
def run(self, message: Message,
|
||||
query: str,
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
def run(self, conversation: Conversation,
|
||||
message: Message,
|
||||
def run(self, message: Message,
|
||||
query: str,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
prompt_template = app_config.prompt_template.simple_prompt_template or ''
|
||||
prompt_messages = self.history_prompt_messages
|
||||
prompt_messages = self.organize_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
query=query,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
prompt_messages = self._init_system_message(prompt_template, prompt_messages)
|
||||
prompt_messages = self._organize_user_query(query, prompt_messages)
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: list[PromptMessageTool] = []
|
||||
@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
agent_thoughts: list[MessageAgentThought] = []
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
}
|
||||
|
||||
tool_responses.append(tool_response)
|
||||
prompt_messages = self.organize_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
query=None,
|
||||
prompt_messages = self._organize_assistant_message(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_call_name=tool_call_name,
|
||||
tool_response=tool_response['tool_response'],
|
||||
@ -324,6 +320,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
||||
@ -386,23 +384,43 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def organize_prompt_messages(self, prompt_template: str,
|
||||
query: str = None,
|
||||
tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
|
||||
prompt_messages: list[PromptMessage] = None
|
||||
) -> list[PromptMessage]:
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
Initialize system message
|
||||
"""
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = [
|
||||
if not prompt_messages and prompt_template:
|
||||
return [
|
||||
SystemPromptMessage(content=prompt_template),
|
||||
UserPromptMessage(content=query),
|
||||
]
|
||||
|
||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
if tool_response:
|
||||
prompt_messages = prompt_messages.copy()
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
|
||||
prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize assistant message
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
if tool_response is not None:
|
||||
prompt_messages.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_response,
|
||||
@ -412,3 +430,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = '\n'.join([
|
||||
content.data if content.type == PromptMessageContentType.TEXT else
|
||||
'[image]' if content.type == PromptMessageContentType.IMAGE else
|
||||
'[file]'
|
||||
for content in prompt_message.content
|
||||
])
|
||||
|
||||
return prompt_messages
|
@ -210,6 +210,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
assistant_cot_runner = CotAgentRunner(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=application_generate_entity.model_config,
|
||||
config=agent_entity,
|
||||
@ -223,7 +224,6 @@ class AgentChatAppRunner(AppRunner):
|
||||
model_instance=model_instance
|
||||
)
|
||||
invoke_result = assistant_cot_runner.run(
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
@ -232,6 +232,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
assistant_fc_runner = FunctionCallAgentRunner(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=application_generate_entity.model_config,
|
||||
config=agent_entity,
|
||||
@ -245,7 +246,6 @@ class AgentChatAppRunner(AppRunner):
|
||||
model_instance=model_instance
|
||||
)
|
||||
invoke_result = assistant_fc_runner.run(
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
)
|
||||
|
@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
|
||||
# clear illegal prompt messages
|
||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
||||
|
||||
# chat model
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
|
||||
return tool_call
|
||||
|
||||
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Clear illegal prompt messages for OpenAI API
|
||||
|
||||
:param model: model name
|
||||
:param prompt_messages: prompt messages
|
||||
:return: cleaned prompt messages
|
||||
"""
|
||||
checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
|
||||
|
||||
if model in checklist:
|
||||
# count how many user messages are there
|
||||
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
|
||||
if user_message_count > 1:
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = '\n'.join([
|
||||
item.data if item.type == PromptMessageContentType.TEXT else
|
||||
'[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
|
||||
for item in prompt_message.content
|
||||
])
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI API
|
||||
|
@ -229,7 +229,7 @@ export const useChat = (
|
||||
|
||||
// answer
|
||||
const responseItem: ChatItem = {
|
||||
id: `${Date.now()}`,
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
agent_thoughts: [],
|
||||
message_files: [],
|
||||
|
Loading…
x
Reference in New Issue
Block a user