Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-13 20:35:55 +08:00)
fix: LLMResultChunk cause concatenate str and list exception (#18852)
This commit is contained in:
parent 993ef87dca
commit c1559a7c8e
@@ -2,7 +2,7 @@ import logging
 import time
 import uuid
 from collections.abc import Generator, Sequence
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 from pydantic import ConfigDict
 
@@ -20,6 +20,7 @@ from core.model_runtime.entities.model_entities import (
     PriceType,
 )
 from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.manager.model import PluginModelManager
 
 logger = logging.getLogger(__name__)
@@ -280,7 +281,9 @@ class LargeLanguageModel(AIModel):
                     callbacks=callbacks,
                 )
 
-                assistant_message.content += chunk.delta.message.content
+                text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
+                current_content = cast(str, assistant_message.content)
+                assistant_message.content = current_content + text
                 real_model = chunk.model
                 if chunk.delta.usage:
                     usage = chunk.delta.usage
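The hunk above is the core of the fix: `assistant_message.content` starts out as a str, so appending a multimodal delta that arrives as a list of content items raised `TypeError: can only concatenate str (not "list") to str`. A minimal, self-contained sketch of that failure mode and the normalize-then-concatenate pattern; the `_Item` dataclass is a hypothetical stand-in for prompt message content objects exposing a str `data` attribute:

# Failure-mode sketch for the change above; _Item is a hypothetical stand-in
# for content items carrying a str "data" attribute.
from dataclasses import dataclass


@dataclass
class _Item:
    data: str


delta_content = [_Item("Hello, "), _Item("world")]

buffer = ""
try:
    buffer += delta_content  # old behaviour: str += list raises
except TypeError as exc:
    print(exc)  # can only concatenate str (not "list") to str

# new behaviour: convert the delta to str first, then concatenate
buffer += "".join(item.data for item in delta_content)
print(buffer)  # Hello, world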
@@ -1,6 +1,8 @@
 import pydantic
 from pydantic import BaseModel
 
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
+
 
 def dump_model(model: BaseModel) -> dict:
     if hasattr(pydantic, "model_dump"):
@@ -8,3 +10,18 @@ def dump_model(model: BaseModel) -> dict:
         return pydantic.model_dump(model)  # type: ignore
     else:
         return model.model_dump()
+
+
+def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
+    if content is None:
+        message_text = ""
+    elif isinstance(content, str):
+        message_text = content
+    elif isinstance(content, list):
+        # Assuming the list contains PromptMessageContent objects with a "data" attribute
+        message_text = "".join(
+            item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
+        )
+    else:
+        message_text = str(content)
+    return message_text
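A hedged usage sketch of the helper added above. It only relies on list items exposing a str `data` attribute and falls back to `str(item)`, so a plain dataclass stands in here for the real PromptMessageContent types:

# Usage sketch for convert_llm_result_chunk_to_str; _Text is a hypothetical
# stand-in for PromptMessageContent items carrying a str "data" attribute.
from dataclasses import dataclass

from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str


@dataclass
class _Text:
    data: str


assert convert_llm_result_chunk_to_str(None) == ""
assert convert_llm_result_chunk_to_str("plain text") == "plain text"
assert convert_llm_result_chunk_to_str([_Text("Hello, "), _Text("world")]) == "Hello, world"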
@@ -38,6 +38,7 @@ from core.model_runtime.entities.model_entities import (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.entities.plugin import ModelProviderID
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -269,18 +270,7 @@ class LLMNode(BaseNode[LLMNodeData]):
 
     def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
+            message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
 
             yield ModelInvokeCompletedEvent(
                 text=message_text,
@@ -295,7 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         usage = None
         finish_reason = None
         for result in invoke_result:
-            text = result.delta.message.content
+            text = convert_llm_result_chunk_to_str(result.delta.message.content)
             full_text += text
 
             yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
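A minimal sketch of the streaming path after this change, assuming hypothetical Chunk/Delta/Message stand-ins for LLMResultChunk and its nested objects; every delta is converted to str before it is accumulated or emitted, so list-typed content no longer breaks the concatenation:

# Streaming accumulation sketch; Chunk/Delta/Message are hypothetical stand-ins
# for LLMResultChunk and its nested delta/message objects.
from dataclasses import dataclass

from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str


@dataclass
class Message:
    content: object


@dataclass
class Delta:
    message: Message


@dataclass
class Chunk:
    delta: Delta


full_text = ""
for chunk in (Chunk(Delta(Message("Hel"))), Chunk(Delta(Message(["lo"])))):
    text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
    full_text += text  # always str + str, even when the delta content is a list

print(full_text)  # Hello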