fix: tool call message role according to credentials (#5625)

Co-authored-by: sunxichen <sun.xc@digitalcnzz.com>
Authored by sunxichen on 2024-06-27 12:35:27 +08:00, committed by GitHub
parent 92c56fdf2b
commit bafc8a0bde

@@ -88,7 +88,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools)
+        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -305,7 +305,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, 'chat/completions')
-            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
         elif completion_type is LLMMode.COMPLETION:
             endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
@@ -582,7 +582,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         return result
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: dict = None) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API format
         """
@@ -636,7 +636,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             #     "tool_call_id": message.tool_call_id
             # }
             message_dict = {
-                "role": "function",
+                "role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
                 "content": message.content,
                 "name": message.tool_call_id
             }
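
The new role condition is narrower than it may look: only function_calling_type == 'tool_call' switches the tool-result message to the "tool" role; a missing field defaults to 'no_call', and every other value (including 'function_call') keeps the legacy "function" role. A small sketch of just that decision, isolated for illustration:

def tool_message_role(credentials: dict = None) -> str:
    # Same condition as the changed line above, extracted into a helper for illustration.
    if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call':
        return "tool"
    return "function"


assert tool_message_role(None) == "function"                                        # callers that pass no credentials
assert tool_message_role({}) == "function"                                          # field missing -> defaults to 'no_call'
assert tool_message_role({"function_calling_type": "function_call"}) == "function"  # legacy functions-style providers
assert tool_message_role({"function_calling_type": "tool_call"}) == "tool"          # tools-style providers
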
@@ -675,7 +675,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         return num_tokens
 
     def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+                                  tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int:
         """
         Approximate num tokens with GPT2 tokenizer.
         """
@@ -684,7 +684,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         tokens_per_name = 1
 
         num_tokens = 0
-        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
        for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():
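
Token counting serializes messages with the same converter, so the credentials have to reach _num_tokens_from_messages as well; otherwise a tool result would be counted under a different role string than the one actually sent. A rough, self-contained sketch of that path, with a whitespace tokenizer standing in for the GPT2 approximation used by the real code:

from dataclasses import dataclass


@dataclass
class ToolPromptMessage:
    # Simplified stand-in for the real ToolPromptMessage.
    content: str
    tool_call_id: str


def approx_token_len(text: str) -> int:
    # Stand-in for the GPT2-based token count in the real implementation.
    return len(text.split())


def convert_prompt_message_to_dict(message: ToolPromptMessage, credentials: dict = None) -> dict:
    # Simplified stand-in covering only the tool-result branch from the diff.
    use_tool_role = bool(credentials) and credentials.get('function_calling_type', 'no_call') == 'tool_call'
    return {"role": "tool" if use_tool_role else "function",
            "content": message.content,
            "name": message.tool_call_id}


def num_tokens_from_messages(messages, credentials: dict = None) -> int:
    tokens_per_message, tokens_per_name = 3, 1
    num_tokens = 0
    # Forwarding credentials keeps the serialized dicts (and the strings being
    # counted) consistent with what the request path will actually send.
    messages_dict = [convert_prompt_message_to_dict(m, credentials) for m in messages]
    for message in messages_dict:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += approx_token_len(value)
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens


messages = [ToolPromptMessage(content="7°C and sunny", tool_call_id="get_weather")]
print(num_tokens_from_messages(messages, {"function_calling_type": "tool_call"}))
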