feat: re-add prompt messages to result and chunks in llm (#17883)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-04-11 18:04:49 +09:00 committed by GitHub
parent 5f8d20b5b2
commit 8e6f6d64a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 24 additions and 15 deletions

View File

@ -177,7 +177,7 @@ class ModelInstance:
) )
def get_llm_num_tokens( def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None
) -> int: ) -> int:
""" """
Get number of tokens for llm Get number of tokens for llm

View File

@ -58,7 +58,7 @@ class Callback(ABC):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -88,7 +88,7 @@ class Callback(ABC):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

View File

@ -74,7 +74,7 @@ class LoggingCallback(Callback):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -104,7 +104,7 @@ class LoggingCallback(Callback):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

View File

@ -1,8 +1,9 @@
from collections.abc import Sequence
from decimal import Decimal from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
@ -107,7 +108,7 @@ class LLMResult(BaseModel):
id: Optional[str] = None id: Optional[str] = None
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
message: AssistantPromptMessage message: AssistantPromptMessage
usage: LLMUsage usage: LLMUsage
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
@ -130,7 +131,7 @@ class LLMResultChunk(BaseModel):
""" """
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
delta: LLMResultChunkDelta delta: LLMResultChunkDelta

View File

@ -45,7 +45,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
""" """
Invoke large language model Invoke large language model
@ -205,22 +205,26 @@ class LargeLanguageModel(AIModel):
user=user, user=user,
callbacks=callbacks, callbacks=callbacks,
) )
# Following https://github.com/langgenius/dify/issues/17799,
return result # we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
result.prompt_messages = prompt_messages
return result
raise NotImplementedError("unsupported invoke result type", type(result))
def _invoke_result_generator( def _invoke_result_generator(
self, self,
model: str, model: str,
result: Generator, result: Generator,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Generator: ) -> Generator[LLMResultChunk, None, None]:
""" """
Invoke result generator Invoke result generator
@ -235,6 +239,10 @@ class LargeLanguageModel(AIModel):
try: try:
for chunk in result: for chunk in result:
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
chunk.prompt_messages = prompt_messages
yield chunk yield chunk
self._trigger_new_chunk_callbacks( self._trigger_new_chunk_callbacks(
@ -403,7 +411,7 @@ class LargeLanguageModel(AIModel):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -450,7 +458,7 @@ class LargeLanguageModel(AIModel):
model: str, model: str,
result: LLMResult, result: LLMResult,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,