fix: Update prompt message content types to use Literal and add union type for content (#17136)

Co-authored-by: 朱庆超 <zhuqingchao@xiaomi.com>
Co-authored-by: crazywoola <427733928@qq.com>
ZalterCitty 2025-04-22 16:17:55 +08:00 committed by GitHub
parent 404f8a790c
commit a1158cc946
10 changed files with 73 additions and 39 deletions


@@ -21,14 +21,13 @@ from core.model_runtime.entities import (
     AssistantPromptMessage,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
@@ -501,7 +500,7 @@ class BaseAgentRunner(AppRunner):
         )
         if not file_objs:
             return UserPromptMessage(content=message.query)
-        prompt_message_contents: list[PromptMessageContent] = []
+        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
         prompt_message_contents.append(TextPromptMessageContent(data=message.query))
         for file in file_objs:
             prompt_message_contents.append(


@@ -5,12 +5,11 @@ from core.file import file_manager
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -40,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
             # get image detail config


@@ -15,14 +15,13 @@ from core.model_runtime.entities import (
     LLMResultChunkDelta,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageContentType,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
@@ -395,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
             # get image detail config


@@ -7,9 +7,9 @@ from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
-    MultiModalPromptMessageContent,
     VideoPromptMessageContent,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from extensions.ext_storage import storage
 
 from . import helpers
@@ -43,7 +43,7 @@ def to_prompt_message_content(
     /,
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-) -> MultiModalPromptMessageContent:
+) -> PromptMessageContentUnionTypes:
     if f.extension is None:
         raise ValueError("Missing file extension")
     if f.mime_type is None:
@@ -58,7 +58,7 @@
     if f.type == FileType.IMAGE:
         params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
 
-    prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
+    prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
         FileType.IMAGE: ImagePromptMessageContent,
         FileType.AUDIO: AudioPromptMessageContent,
         FileType.VIDEO: VideoPromptMessageContent,
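As context for the hunk above: `to_prompt_message_content` keeps its class-map dispatch and only widens the return annotation to the new union. A minimal, self-contained sketch of that dispatch pattern, with hypothetical `DemoFileType`/`Demo*Content` names standing in for `FileType` and the real content classes:

```python
from collections.abc import Mapping
from enum import StrEnum

from pydantic import BaseModel


class DemoFileType(StrEnum):  # hypothetical stand-in for FileType
    IMAGE = "image"
    AUDIO = "audio"


class DemoImageContent(BaseModel):  # hypothetical stand-ins for the content classes
    url: str


class DemoAudioContent(BaseModel):
    url: str


# One table keyed by file type instead of an if/elif chain.
CLASS_MAP: Mapping[DemoFileType, type[BaseModel]] = {
    DemoFileType.IMAGE: DemoImageContent,
    DemoFileType.AUDIO: DemoAudioContent,
}


def build_content(file_type: DemoFileType, **params) -> BaseModel:
    if file_type not in CLASS_MAP:
        raise ValueError(f"file type {file_type} is not supported")
    return CLASS_MAP[file_type](**params)


assert isinstance(build_content(DemoFileType.IMAGE, url="https://example.com/a.jpg"), DemoImageContent)
```

Keeping the per-type branching in data means supporting a new file type is a one-line change to the map.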


@@ -8,11 +8,11 @@ from core.model_runtime.entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageRole,
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from factories import file_factory
@@ -100,7 +100,7 @@ class TokenBufferMemory:
             if not file_objs:
                 prompt_messages.append(UserPromptMessage(content=message.query))
             else:
-                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents: list[PromptMessageContentUnionTypes] = []
                 prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file in file_objs:
                     prompt_message = file_manager.to_prompt_message_content(


@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Any, Optional, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 from pydantic import BaseModel, Field, field_serializer, field_validator
@@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
 class PromptMessageContent(BaseModel):
     """
     Model class for prompt message content.
     """
 
-    type: PromptMessageContentType
+    pass
 
 
 class TextPromptMessageContent(PromptMessageContent):
@@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
     Model class for text prompt message content.
     """
 
-    type: PromptMessageContentType = PromptMessageContentType.TEXT
+    type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
     data: str
@@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
     Model class for multi-modal prompt message content.
     """
 
-    type: PromptMessageContentType
     format: str = Field(default=..., description="the format of multi-modal file")
     base64_data: str = Field(default="", description="the base64 data of multi-modal file")
     url: str = Field(default="", description="the url of multi-modal file")
@@ -94,11 +89,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
 class VideoPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.VIDEO
+    type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
 
 
 class AudioPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.AUDIO
+    type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
 
 
 class ImagePromptMessageContent(MultiModalPromptMessageContent):
@@ -110,12 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
         LOW = "low"
         HIGH = "high"
 
-    type: PromptMessageContentType = PromptMessageContentType.IMAGE
+    type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
 
 
 class DocumentPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
+
+
+PromptMessageContentUnionTypes = Annotated[
+    Union[
+        TextPromptMessageContent,
+        ImagePromptMessageContent,
+        DocumentPromptMessageContent,
+        AudioPromptMessageContent,
+        VideoPromptMessageContent,
+    ],
+    Field(discriminator="type"),
+]
@@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
     """
 
     role: PromptMessageRole
-    content: Optional[str | Sequence[PromptMessageContent]] = None
+    content: Optional[str | list[PromptMessageContentUnionTypes]] = None
     name: Optional[str] = None
 
     def is_empty(self) -> bool:
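The `Literal`-typed `type` fields plus `Field(discriminator="type")` are what make this a Pydantic discriminated (tagged) union: when content arrives as plain dicts, the `type` key alone selects the concrete subclass. A minimal sketch of the pattern, using simplified string literals in place of the `PromptMessageContentType` members:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class TextContent(BaseModel):
    type: Literal["text"] = "text"
    data: str


class ImageContent(BaseModel):
    type: Literal["image"] = "image"
    url: str = ""


# The "type" key alone decides which subclass Pydantic instantiates.
ContentUnion = Annotated[Union[TextContent, ImageContent], Field(discriminator="type")]


class Message(BaseModel):
    content: list[ContentUnion]


msg = Message.model_validate(
    {"content": [{"type": "text", "data": "hi"}, {"type": "image", "url": "https://example.com/a.jpg"}]}
)
assert isinstance(msg.content[0], TextContent)
assert isinstance(msg.content[1], ImageContent)
```

Without the discriminator, a field annotated with the base class alone would be validated against `PromptMessageContent`, and subclass fields would be lost on re-validation.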


@@ -9,13 +9,12 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@@ -125,7 +124,7 @@ class AdvancedPromptTransform(PromptTransform):
             prompt = Jinja2Formatter.format(prompt, prompt_inputs)
 
         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(
@@ -201,7 +200,7 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
 
         if files and query is not None:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file in files:
                 prompt_message_contents.append(


@@ -11,7 +11,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform):
         image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> UserPromptMessage:
         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(


@@ -24,7 +24,7 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     PromptMessageRole,
     SystemPromptMessage,
     UserPromptMessage,
@@ -594,8 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
     ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
-        # FIXME: fix the type error cause prompt_messages is type quick a few times
-        prompt_messages: list[Any] = []
+        prompt_messages: list[PromptMessage] = []
 
         if isinstance(prompt_template, list):
             # For chat model
@@ -657,12 +656,14 @@ class LLMNode(BaseNode[LLMNodeData]):
             # For issue #11247 - Check if prompt content is a string or a list
             prompt_content_type = type(prompt_content)
             if prompt_content_type == str:
+                prompt_content = str(prompt_content)
                 if "#histories#" in prompt_content:
                     prompt_content = prompt_content.replace("#histories#", memory_text)
                 else:
                     prompt_content = memory_text + "\n" + prompt_content
                 prompt_messages[0].content = prompt_content
             elif prompt_content_type == list:
+                prompt_content = prompt_content if isinstance(prompt_content, list) else []
                 for content_item in prompt_content:
                     if content_item.type == PromptMessageContentType.TEXT:
                         if "#histories#" in content_item.data:
@@ -675,9 +676,10 @@ class LLMNode(BaseNode[LLMNodeData]):
             # Add current query to the prompt message
             if sys_query:
                 if prompt_content_type == str:
-                    prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
+                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                     prompt_messages[0].content = prompt_content
                 elif prompt_content_type == list:
+                    prompt_content = prompt_content if isinstance(prompt_content, list) else []
                     for content_item in prompt_content:
                         if content_item.type == PromptMessageContentType.TEXT:
                             content_item.data = sys_query + "\n" + content_item.data
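The `str(...)` and `isinstance(...)` lines added above are effectively no-ops at runtime on well-formed data; they exist to narrow `content`, now declared as `Optional[str | list[PromptMessageContentUnionTypes]]`, so a type checker accepts `.replace()` and iteration. A contrived illustration of the same narrowing (hypothetical function, not from the repo):

```python
from typing import Optional


def replace_query(content: Optional[str | list[str]], query: str) -> str:
    # Without narrowing, a checker rejects content.replace(...):
    # content might still be None or a list at this point.
    if isinstance(content, str):  # narrowed to str in this branch
        return content.replace("#sys.query#", query)
    if isinstance(content, list):  # narrowed to list[str]
        return "\n".join(content).replace("#sys.query#", query)
    return query
```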
@@ -707,7 +709,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
             if isinstance(prompt_message.content, list):
-                prompt_message_content = []
+                prompt_message_content: list[PromptMessageContentUnionTypes] = []
                 for content_item in prompt_message.content:
                     # Skip content if features are not defined
                     if not model_config.model_schema.features:
@@ -1132,7 +1134,9 @@ class LLMNode(BaseNode[LLMNodeData]):
         )
 
 
-def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
+def _combine_message_content_with_role(
+    *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
+):
     match role:
         case PromptMessageRole.USER:
             return UserPromptMessage(content=contents)


@@ -0,0 +1,27 @@
+from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+
+
+def test_build_prompt_message_with_prompt_message_contents():
+    prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
+    assert isinstance(prompt.content, list)
+    assert isinstance(prompt.content[0], TextPromptMessageContent)
+    assert prompt.content[0].data == "Hello, World!"
+
+
+def test_dump_prompt_message():
+    example_url = "https://example.com/image.jpg"
+    prompt = UserPromptMessage(
+        content=[
+            ImagePromptMessageContent(
+                url=example_url,
+                format="jpeg",
+                mime_type="image/jpeg",
+            )
+        ]
+    )
+    data = prompt.model_dump()
+    assert data["content"][0].get("url") == example_url
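A natural companion test, not part of this commit, would close the loop: since `content` is now a discriminated union, re-validating the dumped dict should restore the concrete subclass rather than the base class. A sketch using the same imports as the test file above:

```python
def test_validate_dumped_prompt_message():
    # Hypothetical round-trip check: the "type" discriminator should rebuild
    # the concrete ImagePromptMessageContent from the dumped dict.
    prompt = UserPromptMessage(
        content=[
            ImagePromptMessageContent(
                url="https://example.com/image.jpg",
                format="jpeg",
                mime_type="image/jpeg",
            )
        ]
    )
    restored = UserPromptMessage.model_validate(prompt.model_dump())
    assert isinstance(restored.content, list)
    assert isinstance(restored.content[0], ImagePromptMessageContent)
```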