From a1158cc946c92ca83700411e2d995173eabcefc1 Mon Sep 17 00:00:00 2001
From: ZalterCitty
Date: Tue, 22 Apr 2025 16:17:55 +0800
Subject: [PATCH] fix: Update prompt message content types to use Literal and
 add union type for content (#17136)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: 朱庆超
Co-authored-by: crazywoola <427733928@qq.com>
---
 api/core/agent/base_agent_runner.py           |  5 ++-
 api/core/agent/cot_chat_agent_runner.py       |  5 ++-
 api/core/agent/fc_agent_runner.py             |  5 ++-
 api/core/file/file_manager.py                 |  6 ++--
 api/core/memory/token_buffer_memory.py        |  4 +--
 .../entities/message_entities.py              | 33 +++++++++++--------
 api/core/prompt/advanced_prompt_transform.py  |  7 ++--
 api/core/prompt/simple_prompt_transform.py    |  4 +--
 api/core/workflow/nodes/llm/node.py           | 16 +++++----
 .../core/prompt/test_prompt_message.py        | 27 +++++++++++++++
 10 files changed, 73 insertions(+), 39 deletions(-)
 create mode 100644 api/tests/unit_tests/core/prompt/test_prompt_message.py

diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 48c92ea2db..e648613605 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -21,14 +21,13 @@ from core.model_runtime.entities import (
     AssistantPromptMessage,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
@@ -501,7 +500,7 @@ class BaseAgentRunner(AppRunner):
         )
         if not file_objs:
             return UserPromptMessage(content=message.query)
-        prompt_message_contents: list[PromptMessageContent] = []
+        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
         prompt_message_contents.append(TextPromptMessageContent(data=message.query))
         for file in file_objs:
             prompt_message_contents.append(
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index 7d407a4976..5ff89bdacb 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -5,12 +5,11 @@ from core.file import file_manager
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.utils.encoders import jsonable_encoder
 
 
@@ -40,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
 
             # get image detail config
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index f45fa5c66e..a1110e7709 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -15,14 +15,13 @@ from core.model_runtime.entities import (
     LLMResultChunkDelta,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageContentType,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
@@ -395,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
 
             # get image detail config
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index 4ebe997ac5..9a204e9ff6 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -7,9 +7,9 @@ from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
-    MultiModalPromptMessageContent,
     VideoPromptMessageContent,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from extensions.ext_storage import storage
 
 from . import helpers
@@ -43,7 +43,7 @@ def to_prompt_message_content(
     /,
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-) -> MultiModalPromptMessageContent:
+) -> PromptMessageContentUnionTypes:
     if f.extension is None:
         raise ValueError("Missing file extension")
     if f.mime_type is None:
@@ -58,7 +58,7 @@ def to_prompt_message_content(
     if f.type == FileType.IMAGE:
         params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
 
-    prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
+    prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
         FileType.IMAGE: ImagePromptMessageContent,
         FileType.AUDIO: AudioPromptMessageContent,
         FileType.VIDEO: VideoPromptMessageContent,
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 3c90dd22a2..2254b3d4d5 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -8,11 +8,11 @@ from core.model_runtime.entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageRole,
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from factories import file_factory
@@ -100,7 +100,7 @@ class TokenBufferMemory:
             if not file_objs:
                 prompt_messages.append(UserPromptMessage(content=message.query))
             else:
-                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents: list[PromptMessageContentUnionTypes] = []
                 prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file in file_objs:
                     prompt_message = file_manager.to_prompt_message_content(
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 3bed2460dd..b1c43d1455 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Any, Optional, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 from pydantic import BaseModel, Field, field_serializer, field_validator
 
@@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
 
 
 class PromptMessageContent(BaseModel):
-    """
-    Model class for prompt message content.
-    """
-
-    type: PromptMessageContentType
+    pass
 
 
 class TextPromptMessageContent(PromptMessageContent):
@@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
     Model class for text prompt message content.
     """
 
-    type: PromptMessageContentType = PromptMessageContentType.TEXT
+    type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
     data: str
 
 
@@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
     Model class for multi-modal prompt message content.
     """
 
-    type: PromptMessageContentType
     format: str = Field(default=..., description="the format of multi-modal file")
     base64_data: str = Field(default="", description="the base64 data of multi-modal file")
     url: str = Field(default="", description="the url of multi-modal file")
@@ -94,11 +89,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
 
 
 class VideoPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.VIDEO
+    type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
 
 
 class AudioPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.AUDIO
+    type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
 
 
 class ImagePromptMessageContent(MultiModalPromptMessageContent):
@@ -110,12 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
         LOW = "low"
         HIGH = "high"
 
-    type: PromptMessageContentType = PromptMessageContentType.IMAGE
+    type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
 
 
 class DocumentPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
+
+
+PromptMessageContentUnionTypes = Annotated[
+    Union[
+        TextPromptMessageContent,
+        ImagePromptMessageContent,
+        DocumentPromptMessageContent,
+        AudioPromptMessageContent,
+        VideoPromptMessageContent,
+    ],
+    Field(discriminator="type"),
+]
 
 
 class PromptMessage(BaseModel):
@@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
     """
 
     role: PromptMessageRole
-    content: Optional[str | Sequence[PromptMessageContent]] = None
+    content: Optional[str | list[PromptMessageContentUnionTypes]] = None
     name: Optional[str] = None
 
     def is_empty(self) -> bool:
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index c7427f797e..25964ae063 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -9,13 +9,12 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@@ -125,7 +124,7 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt = Jinja2Formatter.format(prompt, prompt_inputs)
 
         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(
@@ -201,7 +200,7 @@ class AdvancedPromptTransform(PromptTransform):
             prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
 
         if files and query is not None:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file in files:
                 prompt_message_contents.append(
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index ad56d84cb6..47808928f7 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -11,7 +11,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform):
         image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> UserPromptMessage:
         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 8db7394e54..1089e7168e 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -24,7 +24,7 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     PromptMessageRole,
     SystemPromptMessage,
     UserPromptMessage,
@@ -594,8 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
     ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
-        # FIXME: fix the type error cause prompt_messages is type quick a few times
-        prompt_messages: list[Any] = []
+        prompt_messages: list[PromptMessage] = []
 
         if isinstance(prompt_template, list):
             # For chat model
@@ -657,12 +656,14 @@ class LLMNode(BaseNode[LLMNodeData]):
             # For issue #11247 - Check if prompt content is a string or a list
             prompt_content_type = type(prompt_content)
             if prompt_content_type == str:
+                prompt_content = str(prompt_content)
                 if "#histories#" in prompt_content:
                     prompt_content = prompt_content.replace("#histories#", memory_text)
                 else:
                     prompt_content = memory_text + "\n" + prompt_content
                 prompt_messages[0].content = prompt_content
             elif prompt_content_type == list:
+                prompt_content = prompt_content if isinstance(prompt_content, list) else []
                 for content_item in prompt_content:
                     if content_item.type == PromptMessageContentType.TEXT:
                         if "#histories#" in content_item.data:
@@ -675,9 +676,10 @@ class LLMNode(BaseNode[LLMNodeData]):
         # Add current query to the prompt message
         if sys_query:
             if prompt_content_type == str:
-                prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
+                prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                 prompt_messages[0].content = prompt_content
             elif prompt_content_type == list:
+                prompt_content = prompt_content if isinstance(prompt_content, list) else []
                 for content_item in prompt_content:
                     if content_item.type == PromptMessageContentType.TEXT:
                         content_item.data = sys_query + "\n" + content_item.data
@@ -707,7 +709,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
             if isinstance(prompt_message.content, list):
-                prompt_message_content = []
+                prompt_message_content: list[PromptMessageContentUnionTypes] = []
                 for content_item in prompt_message.content:
                     # Skip content if features are not defined
                     if not model_config.model_schema.features:
@@ -1132,7 +1134,9 @@ class LLMNode(BaseNode[LLMNodeData]):
     )
 
 
-def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
+def _combine_message_content_with_role(
+    *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
+):
     match role:
         case PromptMessageRole.USER:
             return UserPromptMessage(content=contents)
diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py
new file mode 100644
index 0000000000..e5da51d733
--- /dev/null
+++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py
@@ -0,0 +1,27 @@
+from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+
+
+def test_build_prompt_message_with_prompt_message_contents():
+    prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
+    assert isinstance(prompt.content, list)
+    assert isinstance(prompt.content[0], TextPromptMessageContent)
+    assert prompt.content[0].data == "Hello, World!"
+
+
+def test_dump_prompt_message():
+    example_url = "https://example.com/image.jpg"
+    prompt = UserPromptMessage(
+        content=[
+            ImagePromptMessageContent(
+                url=example_url,
+                format="jpeg",
+                mime_type="image/jpeg",
+            )
+        ]
+    )
+    data = prompt.model_dump()
+    assert data["content"][0].get("url") == example_url
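
Why the discriminated union matters: each content subclass now declares a
Literal "type" field, so pydantic v2 can pick the concrete subclass both when
serializing and when re-validating dumped data. With the previous
Sequence[PromptMessageContent] annotation, model_dump() serialized list items
against the base class and dropped subclass fields such as "url", which is the
regression test_dump_prompt_message guards against. Below is a minimal sketch
of the resulting behavior, assuming pydantic v2; the variable names are
illustrative and not part of the patch:

    from pydantic import TypeAdapter

    from core.model_runtime.entities.message_entities import (
        ImagePromptMessageContent,
        PromptMessageContentUnionTypes,
        UserPromptMessage,
    )

    # Dump a message whose content holds a concrete subclass.
    prompt = UserPromptMessage(
        content=[
            ImagePromptMessageContent(
                url="https://example.com/image.jpg",
                format="jpeg",
                mime_type="image/jpeg",
            )
        ]
    )
    dumped = prompt.model_dump()  # subclass fields like "url" now survive

    # Re-validation restores ImagePromptMessageContent rather than the bare
    # base class, because the union discriminates on "type".
    restored = UserPromptMessage.model_validate(dumped)
    assert isinstance(restored.content, list)
    assert isinstance(restored.content[0], ImagePromptMessageContent)

    # Standalone payloads can also be validated against the union directly.
    adapter = TypeAdapter(PromptMessageContentUnionTypes)
    item = adapter.validate_python(dumped["content"][0])
    assert isinstance(item, ImagePromptMessageContent)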