mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 06:19:04 +08:00
fix: Update prompt message content types to use Literal and add union type for content (#17136)
Co-authored-by: 朱庆超 <zhuqingchao@xiaomi.com> Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
parent
404f8a790c
commit
a1158cc946
@ -21,14 +21,13 @@ from core.model_runtime.entities import (
|
|||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
LLMUsage,
|
LLMUsage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContent,
|
|
||||||
PromptMessageTool,
|
PromptMessageTool,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
ToolPromptMessage,
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||||
@ -501,7 +500,7 @@ class BaseAgentRunner(AppRunner):
|
|||||||
)
|
)
|
||||||
if not file_objs:
|
if not file_objs:
|
||||||
return UserPromptMessage(content=message.query)
|
return UserPromptMessage(content=message.query)
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
for file in file_objs:
|
for file in file_objs:
|
||||||
prompt_message_contents.append(
|
prompt_message_contents.append(
|
||||||
|
@ -5,12 +5,11 @@ from core.file import file_manager
|
|||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContent,
|
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
|
|
||||||
# get image detail config
|
# get image detail config
|
||||||
|
@ -15,14 +15,13 @@ from core.model_runtime.entities import (
|
|||||||
LLMResultChunkDelta,
|
LLMResultChunkDelta,
|
||||||
LLMUsage,
|
LLMUsage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContent,
|
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
ToolPromptMessage,
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
@ -395,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
|
|
||||||
# get image detail config
|
# get image detail config
|
||||||
|
@ -7,9 +7,9 @@ from core.model_runtime.entities import (
|
|||||||
AudioPromptMessageContent,
|
AudioPromptMessageContent,
|
||||||
DocumentPromptMessageContent,
|
DocumentPromptMessageContent,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
MultiModalPromptMessageContent,
|
|
||||||
VideoPromptMessageContent,
|
VideoPromptMessageContent,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
|
|
||||||
from . import helpers
|
from . import helpers
|
||||||
@ -43,7 +43,7 @@ def to_prompt_message_content(
|
|||||||
/,
|
/,
|
||||||
*,
|
*,
|
||||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||||
) -> MultiModalPromptMessageContent:
|
) -> PromptMessageContentUnionTypes:
|
||||||
if f.extension is None:
|
if f.extension is None:
|
||||||
raise ValueError("Missing file extension")
|
raise ValueError("Missing file extension")
|
||||||
if f.mime_type is None:
|
if f.mime_type is None:
|
||||||
@ -58,7 +58,7 @@ def to_prompt_message_content(
|
|||||||
if f.type == FileType.IMAGE:
|
if f.type == FileType.IMAGE:
|
||||||
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
|
||||||
prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
|
prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
|
||||||
FileType.IMAGE: ImagePromptMessageContent,
|
FileType.IMAGE: ImagePromptMessageContent,
|
||||||
FileType.AUDIO: AudioPromptMessageContent,
|
FileType.AUDIO: AudioPromptMessageContent,
|
||||||
FileType.VIDEO: VideoPromptMessageContent,
|
FileType.VIDEO: VideoPromptMessageContent,
|
||||||
|
@ -8,11 +8,11 @@ from core.model_runtime.entities import (
|
|||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContent,
|
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -100,7 +100,7 @@ class TokenBufferMemory:
|
|||||||
if not file_objs:
|
if not file_objs:
|
||||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
else:
|
else:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
for file in file_objs:
|
for file in file_objs:
|
||||||
prompt_message = file_manager.to_prompt_message_content(
|
prompt_message = file_manager.to_prompt_message_content(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
from typing import Any, Optional, Union
|
from typing import Annotated, Any, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||||
|
|
||||||
@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
class PromptMessageContent(BaseModel):
|
class PromptMessageContent(BaseModel):
|
||||||
"""
|
pass
|
||||||
Model class for prompt message content.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: PromptMessageContentType
|
|
||||||
|
|
||||||
|
|
||||||
class TextPromptMessageContent(PromptMessageContent):
|
class TextPromptMessageContent(PromptMessageContent):
|
||||||
@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
|||||||
Model class for text prompt message content.
|
Model class for text prompt message content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
|
||||||
data: str
|
data: str
|
||||||
|
|
||||||
|
|
||||||
@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
|||||||
Model class for multi-modal prompt message content.
|
Model class for multi-modal prompt message content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: PromptMessageContentType
|
|
||||||
format: str = Field(default=..., description="the format of multi-modal file")
|
format: str = Field(default=..., description="the format of multi-modal file")
|
||||||
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
||||||
url: str = Field(default="", description="the url of multi-modal file")
|
url: str = Field(default="", description="the url of multi-modal file")
|
||||||
@ -94,11 +89,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
|||||||
|
|
||||||
|
|
||||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||||
type: PromptMessageContentType = PromptMessageContentType.VIDEO
|
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
|
||||||
|
|
||||||
|
|
||||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||||
type: PromptMessageContentType = PromptMessageContentType.AUDIO
|
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
|
||||||
|
|
||||||
|
|
||||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||||
@ -110,12 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
|||||||
LOW = "low"
|
LOW = "low"
|
||||||
HIGH = "high"
|
HIGH = "high"
|
||||||
|
|
||||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||||
detail: DETAIL = DETAIL.LOW
|
detail: DETAIL = DETAIL.LOW
|
||||||
|
|
||||||
|
|
||||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||||
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
|
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
|
||||||
|
|
||||||
|
|
||||||
|
PromptMessageContentUnionTypes = Annotated[
|
||||||
|
Union[
|
||||||
|
TextPromptMessageContent,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
DocumentPromptMessageContent,
|
||||||
|
AudioPromptMessageContent,
|
||||||
|
VideoPromptMessageContent,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class PromptMessage(BaseModel):
|
class PromptMessage(BaseModel):
|
||||||
@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
role: PromptMessageRole
|
role: PromptMessageRole
|
||||||
content: Optional[str | Sequence[PromptMessageContent]] = None
|
content: Optional[str | list[PromptMessageContentUnionTypes]] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
def is_empty(self) -> bool:
|
def is_empty(self) -> bool:
|
||||||
|
@ -9,13 +9,12 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
|||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContent,
|
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.prompt_transform import PromptTransform
|
from core.prompt.prompt_transform import PromptTransform
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
@ -125,7 +124,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(
|
prompt_message_contents.append(
|
||||||
@ -201,7 +200,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||||
|
|
||||||
if files and query is not None:
|
if files and query is not None:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(
|
prompt_message_contents.append(
|
||||||
|
@ -11,7 +11,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
|||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContent,
|
PromptMessageContentUnionTypes,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||||
) -> UserPromptMessage:
|
) -> UserPromptMessage:
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(
|
prompt_message_contents.append(
|
||||||
|
@ -24,7 +24,7 @@ from core.model_runtime.entities import (
|
|||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessageContent,
|
PromptMessageContentUnionTypes,
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
@ -594,8 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||||
# FIXME: fix the type error cause prompt_messages is type quick a few times
|
prompt_messages: list[PromptMessage] = []
|
||||||
prompt_messages: list[Any] = []
|
|
||||||
|
|
||||||
if isinstance(prompt_template, list):
|
if isinstance(prompt_template, list):
|
||||||
# For chat model
|
# For chat model
|
||||||
@ -657,12 +656,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
# For issue #11247 - Check if prompt content is a string or a list
|
# For issue #11247 - Check if prompt content is a string or a list
|
||||||
prompt_content_type = type(prompt_content)
|
prompt_content_type = type(prompt_content)
|
||||||
if prompt_content_type == str:
|
if prompt_content_type == str:
|
||||||
|
prompt_content = str(prompt_content)
|
||||||
if "#histories#" in prompt_content:
|
if "#histories#" in prompt_content:
|
||||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||||
else:
|
else:
|
||||||
prompt_content = memory_text + "\n" + prompt_content
|
prompt_content = memory_text + "\n" + prompt_content
|
||||||
prompt_messages[0].content = prompt_content
|
prompt_messages[0].content = prompt_content
|
||||||
elif prompt_content_type == list:
|
elif prompt_content_type == list:
|
||||||
|
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||||
for content_item in prompt_content:
|
for content_item in prompt_content:
|
||||||
if content_item.type == PromptMessageContentType.TEXT:
|
if content_item.type == PromptMessageContentType.TEXT:
|
||||||
if "#histories#" in content_item.data:
|
if "#histories#" in content_item.data:
|
||||||
@ -675,9 +676,10 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
# Add current query to the prompt message
|
# Add current query to the prompt message
|
||||||
if sys_query:
|
if sys_query:
|
||||||
if prompt_content_type == str:
|
if prompt_content_type == str:
|
||||||
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
|
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||||
prompt_messages[0].content = prompt_content
|
prompt_messages[0].content = prompt_content
|
||||||
elif prompt_content_type == list:
|
elif prompt_content_type == list:
|
||||||
|
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||||
for content_item in prompt_content:
|
for content_item in prompt_content:
|
||||||
if content_item.type == PromptMessageContentType.TEXT:
|
if content_item.type == PromptMessageContentType.TEXT:
|
||||||
content_item.data = sys_query + "\n" + content_item.data
|
content_item.data = sys_query + "\n" + content_item.data
|
||||||
@ -707,7 +709,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
filtered_prompt_messages = []
|
filtered_prompt_messages = []
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
if isinstance(prompt_message.content, list):
|
if isinstance(prompt_message.content, list):
|
||||||
prompt_message_content = []
|
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
||||||
for content_item in prompt_message.content:
|
for content_item in prompt_message.content:
|
||||||
# Skip content if features are not defined
|
# Skip content if features are not defined
|
||||||
if not model_config.model_schema.features:
|
if not model_config.model_schema.features:
|
||||||
@ -1132,7 +1134,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
def _combine_message_content_with_role(
|
||||||
|
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||||
|
):
|
||||||
match role:
|
match role:
|
||||||
case PromptMessageRole.USER:
|
case PromptMessageRole.USER:
|
||||||
return UserPromptMessage(content=contents)
|
return UserPromptMessage(content=contents)
|
||||||
|
27
api/tests/unit_tests/core/prompt/test_prompt_message.py
Normal file
27
api/tests/unit_tests/core/prompt/test_prompt_message.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_prompt_message_with_prompt_message_contents():
|
||||||
|
prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
|
||||||
|
assert isinstance(prompt.content, list)
|
||||||
|
assert isinstance(prompt.content[0], TextPromptMessageContent)
|
||||||
|
assert prompt.content[0].data == "Hello, World!"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_prompt_message():
|
||||||
|
example_url = "https://example.com/image.jpg"
|
||||||
|
prompt = UserPromptMessage(
|
||||||
|
content=[
|
||||||
|
ImagePromptMessageContent(
|
||||||
|
url=example_url,
|
||||||
|
format="jpeg",
|
||||||
|
mime_type="image/jpeg",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
data = prompt.model_dump()
|
||||||
|
assert data["content"][0].get("url") == example_url
|
Loading…
x
Reference in New Issue
Block a user