fix: bedrock get_num_tokens prompt_messages parameter name error (#4932)

This commit is contained in:
takatost 2024-06-05 01:53:05 +08:00 committed by GitHub
parent 21ac2afb3a
commit f44d1e62d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -358,26 +358,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_dict return message_dict
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for given prompt messages.

    :param model: model name (expected form "<provider>.<model>", e.g. "anthropic.claude-v2")
    :param credentials: model credentials
    :param prompt_messages: prompt messages or message string
    :param tools: tools for tool calling
    :return: number of tokens, estimated with the GPT-2 tokenizer
    """
    # The model id encodes provider prefix and model name separated by a dot;
    # both are needed to pick the right prompt-conversion format.
    prefix = model.split('.')[0]
    model_name = model.split('.')[1]
    if isinstance(prompt_messages, str):
        # Raw string prompt: count tokens on it directly.
        prompt = prompt_messages
    else:
        # Structured messages: flatten to a provider-specific prompt string first.
        prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)

    return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
""" """