fix(api/core/model_runtime/model_providers/baichuan,localai): Parse ToolPromptMessage. #4943 (#5138)

Co-authored-by: -LAN- <laipz8200@outlook.com>
yanghx 2024-06-13 05:08:30 +00:00 committed by GitHub
parent 742b08e1d5
commit adc948e87c
2 changed files with 88 additions and 54 deletions

Changed file 1 of 2: the Baichuan provider's LLM module (under api/core/model_runtime/model_providers/baichuan)

@@ -7,6 +7,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -46,6 +47,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
     def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
         """Calculate num tokens for baichuan model"""
+
         def tokens(text: str):
             return BaichuanTokenizer._get_num_tokens(text)
@@ -85,6 +87,17 @@ class BaichuanLarguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
+            message = cast(ToolPromptMessage, message)
+            message_dict = {
+                "role": "user",
+                "content": [{
+                    "type": "tool_result",
+                    "tool_use_id": message.tool_call_id,
+                    "content": message.content
+                }]
+            }
         else:
             raise ValueError(f"Unknown message type {type(message)}")
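
For illustration, a minimal standalone sketch of the mapping the new elif branch performs. The ToolPromptMessage class below is a stand-in that models only the two fields the branch reads (tool_call_id and content), not Dify's actual message entity; the resulting dict matches the Anthropic-style tool_result block added above.

    from dataclasses import dataclass

    @dataclass
    class ToolPromptMessage:
        # stand-in for core.model_runtime.entities.message_entities.ToolPromptMessage
        tool_call_id: str
        content: str

    def tool_message_to_dict(message: ToolPromptMessage) -> dict:
        # mirrors the added branch: the tool result goes back to the model as a
        # user-role message carrying a single tool_result content block
        return {
            "role": "user",
            "content": [{
                "type": "tool_result",
                "tool_use_id": message.tool_call_id,
                "content": message.content,
            }],
        }

    print(tool_message_to_dict(ToolPromptMessage(tool_call_id="call_123", content='{"temp": "21C"}')))
    # {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': 'call_123', 'content': '{"temp": "21C"}'}]}
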
@@ -129,7 +142,8 @@
         ]
         # invoke model
-        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60)
+        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
+                                     timeout=60)

         if stream:
             return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
@@ -141,7 +155,9 @@
                                            credentials: dict,
                                            response: BaichuanMessage) -> LLMResult:
         # convert baichuan message to llm result
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens'])
+        usage = self._calc_response_usage(model=model, credentials=credentials,
+                                          prompt_tokens=response.usage['prompt_tokens'],
+                                          completion_tokens=response.usage['completion_tokens'])
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
@@ -158,7 +174,9 @@
                                                   response: Generator[BaichuanMessage, None, None]) -> Generator:
         for message in response:
             if message.usage:
-                usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens'])
+                usage = self._calc_response_usage(model=model, credentials=credentials,
+                                                  prompt_tokens=message.usage['prompt_tokens'],
+                                                  completion_tokens=message.usage['completion_tokens'])
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,

Changed file 2 of 2: the LocalAI provider's LLM module (under api/core/model_runtime/model_providers/localai)

@@ -27,6 +27,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
@@ -69,6 +70,7 @@ class LocalAILanguageModel(LargeLanguageModel):
         Calculate num tokens for baichuan model
         LocalAI does not supports
         """
+
         def tokens(text: str):
             """
             We cloud not determine which tokenizer to use, cause the model is customized.
@@ -133,6 +135,7 @@
         :param tools: tools for tool calling
         :return: number of tokens
         """
+
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
@@ -351,6 +354,17 @@
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
+            message = cast(ToolPromptMessage, message)
+            message_dict = {
+                "role": "user",
+                "content": [{
+                    "type": "tool_result",
+                    "tool_use_id": message.tool_call_id,
+                    "content": message.content
+                }]
+            }
         else:
             raise ValueError(f"Unknown message type {type(message)}")
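
For context, a compact sketch of why both providers needed the branch: before this commit a ToolPromptMessage matched none of the isinstance checks in these converters and fell through to the trailing else, raising the "Unknown message type" ValueError. The classes and the handle_tool_results switch below are illustrative stand-ins for the before/after behaviour, not Dify's real entities or code path.

    from dataclasses import dataclass

    @dataclass
    class UserPromptMessage:
        content: str

    @dataclass
    class ToolPromptMessage:
        tool_call_id: str
        content: str

    def convert(message, handle_tool_results: bool) -> dict:
        # simplified version of the providers' message-to-dict dispatch
        if isinstance(message, UserPromptMessage):
            return {"role": "user", "content": message.content}
        if handle_tool_results and isinstance(message, ToolPromptMessage):
            return {"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": message.tool_call_id,
                "content": message.content,
            }]}
        raise ValueError(f"Unknown message type {type(message)}")

    tool_msg = ToolPromptMessage(tool_call_id="call_1", content="42")
    print(convert(tool_msg, handle_tool_results=True))   # handled after this commit
    try:
        convert(tool_msg, handle_tool_results=False)      # pre-fix behaviour
    except ValueError as exc:
        print(exc)                                        # Unknown message type <class '...ToolPromptMessage'>
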
@@ -407,7 +421,8 @@
         )
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)

         response = LLMResult(
             model=model,
@@ -452,7 +467,8 @@
         prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
         completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)

         response = LLMResult(
             model=model,