From bafc8a0bde29c88539809c734548e8ec5fbd877d Mon Sep 17 00:00:00 2001
From: sunxichen
Date: Thu, 27 Jun 2024 12:35:27 +0800
Subject: [PATCH] fix: tool call message role according to credentials (#5625)

Co-authored-by: sunxichen
---
 .../model_providers/openai_api_compatible/llm/llm.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index f8726c853a..36eae2042d 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -88,7 +88,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools)
+        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -305,7 +305,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, 'chat/completions')
-            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
         elif completion_type is LLMMode.COMPLETION:
             endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
@@ -582,7 +582,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return result
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: dict = None) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API format
         """
@@ -636,7 +636,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             #     "tool_call_id": message.tool_call_id
             # }
             message_dict = {
-                "role": "function",
+                "role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
                 "content": message.content,
                 "name": message.tool_call_id
             }
@@ -675,7 +675,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         return num_tokens
 
     def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+                                  tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int:
         """
         Approximate num tokens with GPT2 tokenizer.
         """
@@ -684,7 +684,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         tokens_per_message = 3
         tokens_per_name = 1
         num_tokens = 0
-        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
         for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
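
For reference, the sketch below isolates the role-selection behavior this patch introduces: when the provider credentials set function_calling_type to 'tool_call', a tool result message is serialized with the OpenAI-style "tool" role, otherwise it falls back to the legacy "function" role. This is a minimal standalone illustration, not the Dify module itself; ToolMessage and convert_tool_message are hypothetical stand-ins, and only the credentials check mirrors the patched _convert_prompt_message_to_dict.

# Standalone sketch of the role selection added by this patch.
# ToolMessage / convert_tool_message are illustrative names, not Dify classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolMessage:
    content: str
    tool_call_id: str


def convert_tool_message(message: ToolMessage, credentials: Optional[dict] = None) -> dict:
    # Providers configured with function_calling_type == 'tool_call' get the
    # OpenAI "tool" role; anything else keeps the legacy "function" role.
    use_tool_role = bool(credentials) and credentials.get('function_calling_type', 'no_call') == 'tool_call'
    return {
        "role": "tool" if use_tool_role else "function",
        "content": message.content,
        "name": message.tool_call_id,
    }


if __name__ == "__main__":
    msg = ToolMessage(content='{"temperature": 22}', tool_call_id="call_0")
    print(convert_tool_message(msg, {"function_calling_type": "tool_call"}))  # role: "tool"
    print(convert_tool_message(msg, {"function_calling_type": "no_call"}))    # role: "function"
    print(convert_tool_message(msg))                                          # role: "function"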