From 5b7b328193e65eacd33d97e652acac2567bc0375 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Tue, 26 Nov 2024 13:45:49 +0800
Subject: [PATCH] feat: Allow files in the system prompt even when the model
 does not support them (#11111)

---
 .../model_providers/anthropic/llm/llm.py |  2 +-
 .../model_providers/openai/llm/llm.py    |  3 +++
 api/core/workflow/nodes/llm/node.py      | 23 ++++++++++++-------
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index b5de02193b..b24324708b 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -453,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
 
         return credentials_kwargs
 
-    def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+    def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
         """
         Convert prompt messages to dict list and system
         """
diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index aea884e002..07cb1e2d10 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -943,6 +943,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             }
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
+            if isinstance(message.content, list):
+                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
+                message.content = "".join(c.data for c in text_contents)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 8653f539a0..2380829f7d 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -20,6 +20,7 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     UserPromptMessage,
@@ -828,14 +829,14 @@ class LLMNode(BaseNode[LLMNodeData]):
         }
 
 
-def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
+def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
         case PromptMessageRole.USER:
-            return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
+            return UserPromptMessage(content=contents)
         case PromptMessageRole.ASSISTANT:
-            return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
+            return AssistantPromptMessage(content=contents)
         case PromptMessageRole.SYSTEM:
-            return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
+            return SystemPromptMessage(content=contents)
 
     raise NotImplementedError(f"Role {role} is not supported")
 
@@ -877,7 +878,9 @@ def _handle_list_messages(
                 jinjia2_variables=jinja2_variables,
                 variable_pool=variable_pool,
             )
-            prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
+            prompt_message = _combine_message_content_with_role(
+                contents=[TextPromptMessageContent(data=result_text)], role=message.role
+            )
             prompt_messages.append(prompt_message)
         else:
             # Get segment group from basic message
@@ -908,12 +911,14 @@ def _handle_list_messages(
         # Create message with text from all segments
         plain_text = segment_group.text
         if plain_text:
-            prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
+            prompt_message = _combine_message_content_with_role(
+                contents=[TextPromptMessageContent(data=plain_text)], role=message.role
+            )
             prompt_messages.append(prompt_message)
 
         if file_contents:
             # Create message with image contents
-            prompt_message = UserPromptMessage(content=file_contents)
+            prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
             prompt_messages.append(prompt_message)
 
     return prompt_messages
@@ -1018,6 +1023,8 @@ def _handle_completion_template(
     else:
         template_text = template.text
     result_text = variable_pool.convert_template(template_text).text
-    prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
+    prompt_message = _combine_message_content_with_role(
+        contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
+    )
    prompt_messages.append(prompt_message)
    return prompt_messages
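
Note for reviewers: below is a minimal standalone sketch (Python 3.10+) of the system-message flattening this patch adds in the OpenAI provider. It is not Dify's actual code; TextContent, ImageContent, and flatten_system_content are hypothetical stand-ins for TextPromptMessageContent, ImagePromptMessageContent, and the inline logic in _convert_prompt_message_to_dict. The idea: when a system prompt arrives as a multi-part content list, keep only the text parts and join them into a plain string, because the provider accepts only plain text in the system role.

# Standalone sketch of the flattening behaviour added in
# api/core/model_runtime/model_providers/openai/llm/llm.py.
# TextContent / ImageContent are hypothetical stand-ins for Dify's
# TextPromptMessageContent / ImagePromptMessageContent.
from dataclasses import dataclass


@dataclass
class TextContent:
    data: str


@dataclass
class ImageContent:
    data: str  # e.g. a base64 payload or URL


def flatten_system_content(content: str | list) -> str:
    # Keep only the text parts of a multi-part system prompt and join
    # them into one string; non-text parts (e.g. images) are dropped
    # because the provider cannot accept them in a system message.
    if isinstance(content, list):
        text_parts = (c for c in content if isinstance(c, TextContent))
        return "".join(c.data for c in text_parts)
    return content


if __name__ == "__main__":
    mixed = [
        TextContent(data="You are a helpful assistant. "),
        ImageContent(data="<base64 image payload>"),
        TextContent(data="Answer concisely."),
    ]
    print(flatten_system_content(mixed))
    # -> "You are a helpful assistant. Answer concisely."

Silently dropping the non-text parts (rather than raising) matches the patch's stated intent: the workflow no longer errors when files end up in a system prompt for a model that cannot consume them there.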