From 4ad3f2cdc2b8fda1e1c126cf49885a27a28015a4 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 9 Apr 2024 15:20:45 +0800 Subject: [PATCH] fix: image text when retrieve chat histories (#3220) --- api/core/memory/token_buffer_memory.py | 15 +++++++- api/core/workflow/workflow_engine_manager.py | 40 ++++++++++++++++++-- api/models/model.py | 2 +- api/models/workflow.py | 4 ++ 4 files changed, 54 insertions(+), 7 deletions(-) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 252b5f1cba..cd0b2508d4 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -3,6 +3,7 @@ from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, PromptMessageRole, TextPromptMessageContent, @@ -124,7 +125,17 @@ class TokenBufferMemory: else: continue - message = f"{role}: {m.content}" - string_messages.append(message) + if isinstance(m.content, list): + inner_msg = "" + for content in m.content: + if isinstance(content, TextPromptMessageContent): + inner_msg += f"{content.data}\n" + elif isinstance(content, ImagePromptMessageContent): + inner_msg += "[image]\n" + + string_messages.append(f"{role}: {inner_msg.strip()}") + else: + message = f"{role}: {m.content}" + string_messages.append(message) return "\n".join(string_messages) \ No newline at end of file diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 532d474258..9390ffa2a4 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,9 +1,10 @@ import logging import time -from typing import Optional +from typing import Optional, cast +from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException -from core.file.file_obj import FileVar +from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -16,6 +17,7 @@ from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode from core.workflow.nodes.start.start_node import StartNode @@ -219,7 +221,8 @@ class WorkflowEngineManager: raise ValueError('node id not found in workflow graph') # Get node class - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) # init workflow run state node_instance = node_cls( @@ -252,11 +255,40 @@ class WorkflowEngineManager: variable_node_id = variable_selector[0] variable_key_list = variable_selector[1:] + # get value + value = user_inputs.get(variable_key) + + # temp fix for image type + if node_type == NodeType.LLM: + new_value = [] + if isinstance(value, list): + node_data = node_instance.node_data + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in value: + if isinstance(item, dict) and 'type' in item and item['type'] == 'image': + transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + file = FileVar( + tenant_id=workflow.tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get( + 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + ) + new_value.append(file) + + if new_value: + value = new_value + # append variable and value to variable pool variable_pool.append_variable( node_id=variable_node_id, variable_key_list=variable_key_list, - value=user_inputs.get(variable_key) + value=value ) # run node node_run_result = node_instance.run( diff --git a/api/models/model.py b/api/models/model.py index d34c577b5d..df858e5219 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -815,7 +815,7 @@ class Message(db.Model): @property def workflow_run(self): if self.workflow_run_id: - from api.models.workflow import WorkflowRun + from .workflow import WorkflowRun return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return None diff --git a/api/models/workflow.py b/api/models/workflow.py index 8db874b471..f65eba3637 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -299,6 +299,10 @@ class WorkflowRun(db.Model): Message.workflow_run_id == self.id ).first() + @property + def workflow(self): + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + class WorkflowNodeExecutionTriggeredFrom(Enum): """