diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 0d5e8a3e4b..0845ef206e 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -177,7 +177,7 @@ class ModelInstance:
         )
 
     def get_llm_num_tokens(
-        self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
+        self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None
     ) -> int:
         """
         Get number of tokens for llm
diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py
index 8870b34435..57cad17285 100644
--- a/api/core/model_runtime/callbacks/base_callback.py
+++ b/api/core/model_runtime/callbacks/base_callback.py
@@ -58,7 +58,7 @@ class Callback(ABC):
         chunk: LLMResultChunk,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[Sequence[str]] = None,
@@ -88,7 +88,7 @@ class Callback(ABC):
         result: LLMResult,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[Sequence[str]] = None,
diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py
index 1f21a2d376..899f08195d 100644
--- a/api/core/model_runtime/callbacks/logging_callback.py
+++ b/api/core/model_runtime/callbacks/logging_callback.py
@@ -74,7 +74,7 @@ class LoggingCallback(Callback):
         chunk: LLMResultChunk,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[Sequence[str]] = None,
@@ -104,7 +104,7 @@ class LoggingCallback(Callback):
         result: LLMResult,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[Sequence[str]] = None,
diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py
index 4523da4388..9bb118622b 100644
--- a/api/core/model_runtime/entities/llm_entities.py
+++ b/api/core/model_runtime/entities/llm_entities.py
@@ -1,8 +1,9 @@
+from collections.abc import Sequence
 from decimal import Decimal
 from enum import StrEnum
 from typing import Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
 from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
@@ -107,7 +108,7 @@ class LLMResult(BaseModel):
 
     id: Optional[str] = None
     model: str
-    prompt_messages: list[PromptMessage]
+    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
     message: AssistantPromptMessage
     usage: LLMUsage
     system_fingerprint: Optional[str] = None
@@ -130,7 +131,7 @@ class LLMResultChunk(BaseModel):
     """
 
     model: str
-    prompt_messages: list[PromptMessage]
+    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
     system_fingerprint: Optional[str] = None
     delta: LLMResultChunkDelta
 
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index b81ccafc1e..53de16d621 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -45,7 +45,7 @@ class LargeLanguageModel(AIModel):
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
-    ) -> Union[LLMResult, Generator]:
+    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
         """
         Invoke large language model
 
@@ -205,22 +205,26 @@ class LargeLanguageModel(AIModel):
                 user=user,
                 callbacks=callbacks,
             )
-
-        return result
+            # Following https://github.com/langgenius/dify/issues/17799,
+            # we removed the prompt_messages from the chunk on the plugin daemon side.
+            # To ensure compatibility, we add the prompt_messages back here.
+            result.prompt_messages = prompt_messages
+            return result
+        raise NotImplementedError("unsupported invoke result type", type(result))
 
     def _invoke_result_generator(
         self,
         model: str,
         result: Generator,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
-    ) -> Generator:
+    ) -> Generator[LLMResultChunk, None, None]:
         """
         Invoke result generator
 
@@ -235,6 +239,10 @@ class LargeLanguageModel(AIModel):
 
         try:
             for chunk in result:
+                # Following https://github.com/langgenius/dify/issues/17799,
+                # we removed the prompt_messages from the chunk on the plugin daemon side.
+                # To ensure compatibility, we add the prompt_messages back here.
+                chunk.prompt_messages = prompt_messages
                 yield chunk
 
                 self._trigger_new_chunk_callbacks(
@@ -403,7 +411,7 @@ class LargeLanguageModel(AIModel):
         chunk: LLMResultChunk,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[Sequence[str]] = None,
@@ -450,7 +458,7 @@ class LargeLanguageModel(AIModel):
         model: str,
         result: LLMResult,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[Sequence[str]] = None,
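Reviewer note: two things are happening in this diff. First, the public annotations move from `list[PromptMessage]` to `Sequence[PromptMessage]`, which is covariant and read-only, so callers may pass a list, a tuple, or any other ordered container; on the Pydantic entities, `Field(default_factory=list)` additionally lets `prompt_messages` be omitted from incoming payloads. Second, per https://github.com/langgenius/dify/issues/17799 the plugin daemon no longer echoes `prompt_messages` back, so `LargeLanguageModel` re-attaches the caller's copy to the blocking `LLMResult` and to every streamed chunk. The sketch below illustrates both behaviours; it uses simplified stand-in models rather than the real Dify entities and assumes Pydantic v2.

```python
from collections.abc import Generator, Sequence

from pydantic import BaseModel, Field


class PromptMessage(BaseModel):
    """Hypothetical stand-in for the real PromptMessage entity."""

    content: str


class LLMResultChunk(BaseModel):
    """Stand-in for the entity in llm_entities.py after this change."""

    model: str
    # Sequence + default_factory: the field accepts any ordered container
    # and may be omitted by the plugin daemon, defaulting to [].
    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
    text: str = ""


def rehydrate(
    chunks: Generator[LLMResultChunk, None, None],
    prompt_messages: Sequence[PromptMessage],
) -> Generator[LLMResultChunk, None, None]:
    # Mirrors the shim in _invoke_result_generator: the daemon no longer
    # echoes prompt_messages, so the caller's copy is re-attached to each
    # chunk before it reaches callbacks and other downstream consumers.
    for chunk in chunks:
        chunk.prompt_messages = prompt_messages
        yield chunk


if __name__ == "__main__":
    prompts = [PromptMessage(content="Hello")]
    # Daemon payloads omit prompt_messages entirely and still validate.
    daemon_stream = (LLMResultChunk(model="gpt-4o", text=t) for t in ("Hi", "!"))
    for chunk in rehydrate(daemon_stream, prompts):
        assert chunk.prompt_messages == prompts
```

Doing the re-attachment once in the base class keeps existing consumers of `chunk.prompt_messages` working without requiring the daemon to resend the (potentially large) prompt with every chunk.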