feat: Allow files to be included in the system prompt even when the model does not support them. (#11111)

This commit is contained in:
-LAN- 2024-11-26 13:45:49 +08:00 committed by GitHub
parent 8d5a1be227
commit 5b7b328193
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 9 deletions

View File

@@ -453,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return credentials_kwargs return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
""" """
Convert prompt messages to dict list and system Convert prompt messages to dict list and system
""" """

View File

@@ -943,6 +943,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
} }
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
if isinstance(message.content, list):
text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
message.content = "".join(c.data for c in text_contents)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage): elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message) message = cast(ToolPromptMessage, message)

View File

@@ -20,6 +20,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessageContent,
PromptMessageRole, PromptMessageRole,
SystemPromptMessage, SystemPromptMessage,
UserPromptMessage, UserPromptMessage,
@@ -828,14 +829,14 @@ class LLMNode(BaseNode[LLMNodeData]):
} }
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role: match role:
case PromptMessageRole.USER: case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT: case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM: case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) return SystemPromptMessage(content=contents)
raise NotImplementedError(f"Role {role} is not supported") raise NotImplementedError(f"Role {role} is not supported")
@@ -877,7 +878,9 @@ def _handle_list_messages(
jinjia2_variables=jinja2_variables, jinjia2_variables=jinja2_variables,
variable_pool=variable_pool, variable_pool=variable_pool,
) )
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message) prompt_messages.append(prompt_message)
else: else:
# Get segment group from basic message # Get segment group from basic message
@@ -908,12 +911,14 @@ def _handle_list_messages(
# Create message with text from all segments # Create message with text from all segments
plain_text = segment_group.text plain_text = segment_group.text
if plain_text: if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message) prompt_messages.append(prompt_message)
if file_contents: if file_contents:
# Create message with image contents # Create message with image contents
prompt_message = UserPromptMessage(content=file_contents) prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message) prompt_messages.append(prompt_message)
return prompt_messages return prompt_messages
@@ -1018,6 +1023,8 @@ def _handle_completion_template(
else: else:
template_text = template.text template_text = template.text
result_text = variable_pool.convert_template(template_text).text result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
)
prompt_messages.append(prompt_message) prompt_messages.append(prompt_message)
return prompt_messages return prompt_messages