fix(api/core/model_runtime/model_providers/baichuan,localai): Parse ToolPromptMessage. #4943 (#5138)
Co-authored-by: -LAN- <laipz8200@outlook.com>
Parent: 742b08e1d5
Commit: adc948e87c
api/core/model_runtime/model_providers/baichuan/llm/llm.py

@@ -7,6 +7,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -38,14 +39,15 @@ class BaichuanLarguageModel(LargeLanguageModel):
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                               model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
         return self._num_tokens_from_messages(prompt_messages)

-    def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
         """Calculate num tokens for baichuan model"""

         def tokens(text: str):
             return BaichuanTokenizer._get_num_tokens(text)

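Note for readers: `_num_tokens_from_messages` walks the prompt messages and applies the Baichuan tokenizer to their text. A rough, self-contained sketch of that idea in Python follows; the estimator below is only a stand-in for `BaichuanTokenizer._get_num_tokens`, and the real per-message accounting may differ.

import re


def estimate_tokens(text: str) -> int:
    # Stand-in estimator (illustrative only): CJK characters counted individually, latin words coarsely.
    han_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    latin_words = len(re.findall(r"[A-Za-z0-9]+", text))
    return han_chars + latin_words


def num_tokens_from_messages(messages: list[dict]) -> int:
    # Sketch of the helper: sum the estimate over each message's role and content.
    return sum(
        estimate_tokens(m.get("role", "")) + estimate_tokens(str(m.get("content", "")))
        for m in messages
    )


print(num_tokens_from_messages([{"role": "user", "content": "你好, Baichuan!"}]))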
@@ -85,6 +87,17 @@ class BaichuanLarguageModel(LargeLanguageModel):
             elif isinstance(message, SystemPromptMessage):
                 message = cast(SystemPromptMessage, message)
                 message_dict = {"role": "user", "content": message.content}
+            elif isinstance(message, ToolPromptMessage):
+                # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
+                message = cast(ToolPromptMessage, message)
+                message_dict = {
+                    "role": "user",
+                    "content": [{
+                        "type": "tool_result",
+                        "tool_use_id": message.tool_call_id,
+                        "content": message.content
+                    }]
+                }
             else:
                 raise ValueError(f"Unknown message type {type(message)}")

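The added branch handles histories that already contain tool results (for example from an agent run); previously such a `ToolPromptMessage` fell through to the `else` branch and raised `ValueError`. A minimal sketch of the new mapping, using an illustrative stand-in for the real `ToolPromptMessage` class:

from dataclasses import dataclass


@dataclass
class ToolPromptMessage:
    # Illustrative stand-in for core.model_runtime.entities.message_entities.ToolPromptMessage.
    tool_call_id: str
    content: str


def tool_message_to_dict(message: ToolPromptMessage) -> dict:
    # Same shape as the added branch: an Anthropic-style tool_result block under the user role.
    return {
        "role": "user",
        "content": [{
            "type": "tool_result",
            "tool_use_id": message.tool_call_id,
            "content": message.content,
        }],
    }


print(tool_message_to_dict(ToolPromptMessage(tool_call_id="call_0", content="42")))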
@@ -107,8 +120,8 @@ class BaichuanLarguageModel(LargeLanguageModel):
             raise CredentialsValidateFailedError(f"Invalid API key: {e}")

     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         if tools is not None and len(tools) > 0:
             raise InvokeBadRequestError("Baichuan model doesn't support tools")
@@ -129,7 +142,8 @@ class BaichuanLarguageModel(LargeLanguageModel):
         ]

         # invoke model
-        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60)
+        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
+                                     timeout=60)

         if stream:
             return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
@@ -141,7 +155,9 @@ class BaichuanLarguageModel(LargeLanguageModel):
                                        credentials: dict,
                                        response: BaichuanMessage) -> LLMResult:
         # convert baichuan message to llm result
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens'])
+        usage = self._calc_response_usage(model=model, credentials=credentials,
+                                          prompt_tokens=response.usage['prompt_tokens'],
+                                          completion_tokens=response.usage['completion_tokens'])
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
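The wrapped call above is a line-length reformat only; for context, `_calc_response_usage` turns the provider-reported token counts into a usage record. A simplified sketch, assuming a minimal usage entity (the runtime's real one carries additional fields such as pricing, omitted here):

from dataclasses import dataclass


@dataclass
class LLMUsage:
    # Simplified stand-in for the runtime's usage entity.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


def calc_response_usage(prompt_tokens: int, completion_tokens: int) -> LLMUsage:
    # Sketch only: the real helper may also attach pricing details.
    return LLMUsage(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)


usage = calc_response_usage(prompt_tokens=128, completion_tokens=64)
print(usage.total_tokens)  # 192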
@@ -158,7 +174,9 @@ class BaichuanLarguageModel(LargeLanguageModel):
                                               response: Generator[BaichuanMessage, None, None]) -> Generator:
         for message in response:
             if message.usage:
-                usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens'])
+                usage = self._calc_response_usage(model=model, credentials=credentials,
+                                                  prompt_tokens=message.usage['prompt_tokens'],
+                                                  completion_tokens=message.usage['completion_tokens'])
             yield LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
api/core/model_runtime/model_providers/localai/llm/llm.py

@@ -27,6 +27,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
@@ -57,7 +58,7 @@ class LocalAILanguageModel(LargeLanguageModel):
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                               model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
@@ -69,6 +70,7 @@ class LocalAILanguageModel(LargeLanguageModel):
         Calculate num tokens for baichuan model
         LocalAI does not supports
         """

         def tokens(text: str):
             """
             We cloud not determine which tokenizer to use, cause the model is customized.
@@ -133,6 +135,7 @@ class LocalAILanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :return: number of tokens
         """

         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)

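The surrounding docstring explains the approximation: LocalAI serves a user-deployed model whose tokenizer is unknown, so the provider falls back to a GPT-2 count. A hedged sketch of such a fallback with Hugging Face tokenizers, assuming `transformers` is installed; the runtime's `_get_num_tokens_by_gpt2` helper may be implemented differently:

from transformers import GPT2TokenizerFast

# Loaded once; the runtime may cache or bundle the tokenizer files differently.
_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")


def get_num_tokens_by_gpt2(text: str) -> int:
    # Approximate token count when the served model's own tokenizer is unknown.
    return len(_gpt2.encode(text))


print(get_num_tokens_by_gpt2("LocalAI serves a customized model, so counts are estimates."))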
@@ -247,8 +250,8 @@ class LocalAILanguageModel(LargeLanguageModel):
         return entity

     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:

         kwargs = self._to_client_kwargs(credentials)
@@ -351,6 +354,17 @@ class LocalAILanguageModel(LargeLanguageModel):
             elif isinstance(message, SystemPromptMessage):
                 message = cast(SystemPromptMessage, message)
                 message_dict = {"role": "system", "content": message.content}
+            elif isinstance(message, ToolPromptMessage):
+                # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
+                message = cast(ToolPromptMessage, message)
+                message_dict = {
+                    "role": "user",
+                    "content": [{
+                        "type": "tool_result",
+                        "tool_use_id": message.tool_call_id,
+                        "content": message.content
+                    }]
+                }
             else:
                 raise ValueError(f"Unknown message type {type(message)}")

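The LocalAI converter gains the same `tool_result` mapping; note that it keeps system prompts under the "system" role, whereas the Baichuan converter above folds them into "user". For illustration, a tool-call round trip now serializes to message dicts shaped like this (ids and contents are invented; only the shape comes from the diff):

# Shape of a converted conversation once ToolPromptMessage is handled.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "assistant", "content": "Let me check the weather tool."},
    {
        "role": "user",
        "content": [{
            "type": "tool_result",
            "tool_use_id": "call_0",            # hypothetical tool call id
            "content": "Paris: 21°C, clear",    # tool output fed back to the model
        }],
    },
]

for m in messages:
    print(m["role"], "->", m["content"])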
@@ -377,10 +391,10 @@ class LocalAILanguageModel(LargeLanguageModel):
         return prompts

     def _handle_completion_generate_response(self, model: str,
                                              prompt_messages: list[PromptMessage],
                                              credentials: dict,
                                              response: Completion,
                                              ) -> LLMResult:
         """
         Handle llm chat response

@@ -407,7 +421,8 @@ class LocalAILanguageModel(LargeLanguageModel):
         )
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])

-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)

         response = LLMResult(
             model=model,
@@ -452,7 +467,8 @@ class LocalAILanguageModel(LargeLanguageModel):
         prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)

-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)

         response = LLMResult(
             model=model,
@@ -465,10 +481,10 @@ class LocalAILanguageModel(LargeLanguageModel):
         return response

     def _handle_completion_generate_stream_response(self, model: str,
                                                     prompt_messages: list[PromptMessage],
                                                     credentials: dict,
                                                     response: Stream[Completion],
                                                     tools: list[PromptMessageTool]) -> Generator:
         full_response = ''

         for chunk in response: