From d018b32d0b75cfaa6f6f38fb7e5e15984f0172ff Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 24 Oct 2024 17:52:11 +0800 Subject: [PATCH] fix(workflow): enhance prompt handling with vision support (#9790) --- api/core/workflow/nodes/llm/node.py | 8 +++++++- .../nodes/question_classifier/question_classifier_node.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 94aa8c5eab..abf77f3339 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -127,9 +127,10 @@ class LLMNode(BaseNode[LLMNodeData]): context=context, memory=memory, model_config=model_config, - vision_detail=self.node_data.vision.configs.detail, prompt_template=self.node_data.prompt_template, memory_config=self.node_data.memory, + vision_enabled=self.node_data.vision.enabled, + vision_detail=self.node_data.vision.configs.detail, ) process_data = { @@ -518,6 +519,7 @@ class LLMNode(BaseNode[LLMNodeData]): model_config: ModelConfigWithCredentialsEntity, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, memory_config: MemoryConfig | None = None, + vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = inputs or {} @@ -542,6 +544,10 @@ class LLMNode(BaseNode[LLMNodeData]): if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content or []: + # Skip image if vision is disabled + if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + continue + if isinstance(content_item, ImagePromptMessageContent): # Override vision config if LLM node has vision config, # cuz vision detail is related to the configuration from FileUpload feature. diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index e6af453dcf..ee160e7c69 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -88,6 +88,7 @@ class QuestionClassifierNode(LLMNode): memory=memory, model_config=model_config, files=files, + vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, )