From c1559a7c8ea1bff27c82f65b83cdab0e96427ec6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sun, 27 Apr 2025 11:32:14 +0800
Subject: [PATCH] fix: LLMResultChunk cause concatenate str and list
 exception (#18852)

---
 .../__base/large_language_model.py      |  7 +++++--
 api/core/model_runtime/utils/helper.py  | 17 +++++++++++++++++
 api/core/workflow/nodes/llm/node.py     | 16 +++-------------
 3 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 1b799131e7..1a4e4a1537 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -2,7 +2,7 @@ import logging
 import time
 import uuid
 from collections.abc import Generator, Sequence
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 from pydantic import ConfigDict
 
@@ -20,6 +20,7 @@ from core.model_runtime.entities.model_entities import (
     PriceType,
 )
 from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.manager.model import PluginModelManager
 
 logger = logging.getLogger(__name__)
@@ -280,7 +281,9 @@ class LargeLanguageModel(AIModel):
                     callbacks=callbacks,
                 )
 
-                assistant_message.content += chunk.delta.message.content
+                text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
+                current_content = cast(str, assistant_message.content)
+                assistant_message.content = current_content + text
                 real_model = chunk.model
                 if chunk.delta.usage:
                     usage = chunk.delta.usage
diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py
index 5e8a723ec7..53789a8e91 100644
--- a/api/core/model_runtime/utils/helper.py
+++ b/api/core/model_runtime/utils/helper.py
@@ -1,6 +1,8 @@
 import pydantic
 from pydantic import BaseModel
 
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
+
 
 def dump_model(model: BaseModel) -> dict:
     if hasattr(pydantic, "model_dump"):
@@ -8,3 +10,18 @@ def dump_model(model: BaseModel) -> dict:
         return pydantic.model_dump(model)  # type: ignore
     else:
         return model.model_dump()
+
+
+def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
+    if content is None:
+        message_text = ""
+    elif isinstance(content, str):
+        message_text = content
+    elif isinstance(content, list):
+        # Assuming the list contains PromptMessageContent objects with a "data" attribute
+        message_text = "".join(
+            item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
+        )
+    else:
+        message_text = str(content)
+    return message_text
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 1089e7168e..35b146e5d9 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -38,6 +38,7 @@ from core.model_runtime.entities.model_entities import (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.entities.plugin import ModelProviderID
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -269,18 +270,7 @@ class LLMNode(BaseNode[LLMNodeData]):
 
     def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
+            message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
 
             yield ModelInvokeCompletedEvent(
                 text=message_text,
@@ -295,7 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         usage = None
         finish_reason = None
         for result in invoke_result:
-            text = result.delta.message.content
+            text = convert_llm_result_chunk_to_str(result.delta.message.content)
             full_text += text
 
             yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
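
Reviewer note: a minimal, self-contained sketch of the failure mode this
patch addresses. FakeTextContent is a hypothetical stub standing in for
Dify's PromptMessageContent entities, and the convert_llm_result_chunk_to_str
body is a condensed restatement of the helper added above; only the stub
class and the driver lines at the bottom are invented for illustration.

from dataclasses import dataclass


@dataclass
class FakeTextContent:
    data: str  # the "data" attribute the helper duck-types against


def convert_llm_result_chunk_to_str(content) -> str:
    # Same branches as the helper added to api/core/model_runtime/utils/helper.py,
    # restated with early returns: None -> "", str passes through, lists of
    # content parts are joined on their "data" text, anything else is str()-ed.
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item)
            for item in content
        )
    return str(content)


accumulated = ""  # assistant_message.content starts out as a plain str

# Before this patch, a list-typed delta hit `str += list` and raised:
# TypeError: can only concatenate str (not "list") to str
chunk_content = [FakeTextContent(data="Hello, "), FakeTextContent(data="world")]
accumulated += convert_llm_result_chunk_to_str(chunk_content)  # list -> joined text
accumulated += convert_llm_result_chunk_to_str(None)           # None -> ""
accumulated += convert_llm_result_chunk_to_str("!")            # str passes through

assert accumulated == "Hello, world!"

On the cast(str, ...) added in large_language_model.py: the cast only narrows
the type for the checker and does nothing at runtime; safety rests on every
appended piece passing through convert_llm_result_chunk_to_str first, which
keeps assistant_message.content a plain str across iterations.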