From f44d1e62d24f910ac1a55e055a437e6b48d7b317 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 5 Jun 2024 01:53:05 +0800
Subject: [PATCH] fix: bedrock get_num_tokens prompt_messages parameter name err (#4932)

---
 .../model_runtime/model_providers/bedrock/llm/llm.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 81a9ce2f00..1386d680a4 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -358,26 +358,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         return message_dict
 
-    def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
         Get number of tokens for given prompt messages
 
         :param model: model name
         :param credentials: model credentials
-        :param messages: prompt messages or message string
+        :param prompt_messages: prompt messages or message string
         :param tools: tools for tool calling
         :return:md = genai.GenerativeModel(model)
         """
         prefix = model.split('.')[0]
         model_name = model.split('.')[1]
-        if isinstance(messages, str):
-            prompt = messages
+        if isinstance(prompt_messages, str):
+            prompt = prompt_messages
         else:
-            prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
+            prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
 
         return self._get_num_tokens_by_gpt2(prompt)
-
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
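
Why the rename matters: Python binds keyword arguments by parameter name, so an
override whose parameter is named `messages` rejects any caller that passes
`prompt_messages=` per the base-class signature. A minimal sketch with
hypothetical stand-in classes (not the real Dify types):

    # Hypothetical stand-ins illustrating the bug fixed by this patch:
    # an override that renames a parameter breaks keyword callers that
    # follow the base-class signature.

    class LargeLanguageModel:
        def get_num_tokens(self, model: str, credentials: dict,
                           prompt_messages, tools=None) -> int:
            raise NotImplementedError

    class OldBedrockModel(LargeLanguageModel):
        # Before the patch: parameter named `messages`, not `prompt_messages`.
        def get_num_tokens(self, model: str, credentials: dict,
                           messages, tools=None) -> int:
            return len(str(messages))

    llm = OldBedrockModel()

    # Positional calls work regardless of the parameter name:
    print(llm.get_num_tokens('anthropic.claude-v2', {}, 'hello'))  # 5

    # Keyword calls that follow the base-class signature fail on the old override:
    try:
        llm.get_num_tokens('anthropic.claude-v2', {}, prompt_messages='hello')
    except TypeError as err:
        print(err)  # ... unexpected keyword argument 'prompt_messages'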