From ba67206bb9f463777822c8c3cb6995bc9149f7bc Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Mon, 24 Jun 2024 15:35:21 +0800
Subject: [PATCH] fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)

---
 .../model_providers/azure_openai/llm/llm.py   | 269 +++++++++---------
 .../model_runtime/__mock/openai_chat.py       |  10 +-
 2 files changed, 137 insertions(+), 142 deletions(-)

diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
index eb6d985f23..25bc94cde6 100644
--- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -1,14 +1,13 @@
 import copy
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
 
 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageFunction,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
+from core.model_runtime.utils import helper
 
 logger = logging.getLogger(__name__)
 
@@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
 
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+        if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
             return self._chat_generate(
                 model=model,
@@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 user=user
             )
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
+        if not model_entity:
+            raise ValueError(f'Base Model Name {base_model_name} is invalid')
+        model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
 
         if model_mode == LLMMode.CHAT.value:
             # chat model
             return self._num_tokens_from_messages(credentials, prompt_messages, tools)
         else:
            # text completion model, do not support tool calling
-            return self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            return self._num_tokens_from_string(credentials, content)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'openai_api_base' not in credentials:
@@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if 'base_model_name' not in credentials:
             raise CredentialsValidateFailedError('Base Model Name is required')
 
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise CredentialsValidateFailedError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None
 
     def _generate(self, model: str, credentials: dict,
@@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+        self, model: str, credentials: dict, response: Completion,
+        prompt_messages: list[PromptMessage]
+    ):
         assistant_text = response.choices[0].text
 
         # transform assistant message to prompt message
@@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = response.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
 
         # transform usage
@@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self, model: str, credentials: dict, response: Stream[Completion],
+        prompt_messages: list[PromptMessage]
+    ) -> Generator:
         full_text = ''
         for chunk in response:
             if len(chunk.choices) == 0:
@@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 completion_tokens = chunk.usage.completion_tokens
             else:
                 # calculate num tokens
-                prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+                content = prompt_messages[0].content
+                assert isinstance(content, str)
+                prompt_tokens = self._num_tokens_from_string(credentials, content)
                 completion_tokens = self._num_tokens_from_string(credentials, full_text)
 
             # transform usage
@@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         extra_model_kwargs = {}
 
         if tools:
-            # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            extra_model_kwargs['functions'] = [{
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters
-            } for tool in tools]
+            extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+            # extra_model_kwargs['functions'] = [{
+            #     "name": tool.name,
+            #     "description": tool.description,
+            #     "parameters": tool.parameters
+            # } for tool in tools]
 
         if stop:
             extra_model_kwargs['stop'] = stop
@@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             extra_model_kwargs['user'] = user
 
         # chat model
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
-                                       prompt_messages: list[PromptMessage],
-                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
-
+    def _handle_chat_generate_response(
+        self, model: str, credentials: dict, response: ChatCompletion,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         assistant_message = response.choices[0].message
-        # assistant_message_tool_calls = assistant_message.tool_calls
-        assistant_message_function_call = assistant_message.function_call
+        assistant_message_tool_calls = assistant_message.tool_calls
 
         # extract tool calls from response
-        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-        function_call = self._extract_response_function_call(assistant_message_function_call)
-        tool_calls = [function_call] if function_call else []
+        tool_calls = []
+        self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
 
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
         # transform response
-        response = LLMResult(
+        result = LLMResult(
             model=response.model or model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
@@ -325,58 +354,34 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             system_fingerprint=response.system_fingerprint,
         )
 
-        return response
+        return result
 
-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
-                                              response: Stream[ChatCompletionChunk],
-                                              prompt_messages: list[PromptMessage],
-                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: Stream[ChatCompletionChunk],
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         index = 0
         full_assistant_content = ''
-        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
+        tool_calls = []
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue
 
             delta = chunk.choices[0]
 
+            # extract tool calls from response
+            self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
+
             # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
-            if delta.delta is None or (
-                delta.finish_reason is None
-                and (delta.delta.content is None or delta.delta.content == '')
-                and delta.delta.function_call is None
-            ):
+            if delta.finish_reason is None and not delta.delta.content:
                 continue
-
-            # assistant_message_tool_calls = delta.delta.tool_calls
-            assistant_message_function_call = delta.delta.function_call
-
-            # extract tool calls from response
-            if delta_assistant_message_function_call_storage is not None:
-                # handle process of stream function call
-                if assistant_message_function_call:
-                    # message has not ended ever
-                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
-                    continue
-                else:
-                    # message has ended
-                    assistant_message_function_call = delta_assistant_message_function_call_storage
-                    delta_assistant_message_function_call_storage = None
-            else:
-                if assistant_message_function_call:
-                    # start of stream function call
-                    delta_assistant_message_function_call_storage = assistant_message_function_call
-                    if delta_assistant_message_function_call_storage.arguments is None:
-                        delta_assistant_message_function_call_storage.arguments = ''
-                    continue
-
-            # extract tool calls from response
-            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-            function_call = self._extract_response_function_call(assistant_message_function_call)
-            tool_calls = [function_call] if function_call else []
 
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
@@ -426,54 +431,56 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             )
 
     @staticmethod
-    def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-            -> list[AssistantPromptMessage.ToolCall]:
+    def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
+        if tool_calls_response:
+            for response_tool_call in tool_calls_response:
+                if isinstance(response_tool_call, ChatCompletionMessageToolCall):
+                    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=response_tool_call.function.name,
+                        arguments=response_tool_call.function.arguments
+                    )
 
-        tool_calls = []
-        if response_tool_calls:
-            for response_tool_call in response_tool_calls:
-                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call.function.name,
-                    arguments=response_tool_call.function.arguments
-                )
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=response_tool_call.id,
+                        type=response_tool_call.type,
+                        function=function
+                    )
+                    tool_calls.append(tool_call)
+                elif isinstance(response_tool_call, ChoiceDeltaToolCall):
+                    index = response_tool_call.index
+                    if index < len(tool_calls):
+                        tool_calls[index].id = response_tool_call.id or tool_calls[index].id
+                        tool_calls[index].type = response_tool_call.type or tool_calls[index].type
+                        if response_tool_call.function:
+                            tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
+                            tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
+                    else:
+                        assert response_tool_call.id is not None
+                        assert response_tool_call.type is not None
+                        assert response_tool_call.function is not None
+                        assert response_tool_call.function.name is not None
+                        assert response_tool_call.function.arguments is not None
 
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call.id,
-                    type=response_tool_call.type,
-                    function=function
-                )
-                tool_calls.append(tool_call)
-
-        return tool_calls
+                        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=response_tool_call.function.name,
+                            arguments=response_tool_call.function.arguments
+                        )
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=response_tool_call.id,
+                            type=response_tool_call.type,
+                            function=function
+                        )
+                        tool_calls.append(tool_call)
 
     @staticmethod
-    def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-            -> AssistantPromptMessage.ToolCall:
-
-        tool_call = None
-        if response_function_call:
-            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call.name,
-                arguments=response_function_call.arguments
-            )
-
-            tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call.name,
-                type="function",
-                function=function
-            )
-
-        return tool_call
-
-    @staticmethod
-    def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
-
+    def _convert_prompt_message_to_dict(message: PromptMessage):
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
                 sub_messages = []
+                assert message.content is not None
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
@@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                         }
                     }
                     sub_messages.append(sub_message_dict)
-
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
-                #                               message.tool_calls]
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
+                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
{"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - # message_dict = { - # "role": "tool", - # "content": message.content, - # "tool_call_id": message.tool_call_id - # } message_dict = { - "role": "function", + "role": "tool", + "name": message.name, "content": message.content, - "name": message.tool_call_id + "tool_call_id": message.tool_call_id } else: raise ValueError(f"Got unknown type {message}") @@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, credentials: dict, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if key == "tool_calls": for tool_call in value: + assert isinstance(tool_call, dict) for t_key, t_value in tool_call.items(): num_tokens += len(encoding.encode(t_key)) if t_key == "function": @@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): num_tokens += len(encoding.encode('parameters')) if 'title' in parameters: num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters.get("title"))) + num_tokens += len(encoding.encode(parameters['title'])) num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters.get("type"))) + num_tokens += len(encoding.encode(parameters['type'])) if 'properties' in parameters: num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + for key, value in parameters['properties'].items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) @@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + def _get_ai_model_entity(base_model_name: str, model: str): for ai_model_entity in LLM_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: ai_model_entity_copy = copy.deepcopy(ai_model_entity) @@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.zh_Hans = model return ai_model_entity_copy - - return None diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index c5af941d1e..ba902e32ea 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -73,17 +73,15 @@ class MockChatClass: return FunctionCall(name=function_name, arguments=dumps(parameters)) @staticmethod - def generate_tool_calls( - tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, - ) -> Optional[list[ChatCompletionMessageToolCall]]: + def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: list_tool_calls = [] if not tools or len(tools) 
             return None
-        tool: ChatCompletionToolParam = tools[0]
+        tool = tools[0]
 
-        if tools['type'] != 'function':
+        if 'type' in tool and tool['type'] != 'function':
             return None
-
+
         function = tool['function']
         function_call = MockChatClass.generate_function_call(functions=[function])
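
Note: the heart of this patch is `_update_tool_calls`, which accumulates streamed tool-call deltas by `index` instead of buffering a single deprecated `function_call`. The sketch below illustrates that accumulation strategy in isolation; it is a minimal approximation, not the shipped code, and the dataclasses are hypothetical stand-ins for the OpenAI SDK's `ChoiceDeltaToolCall` and Dify's `AssistantPromptMessage.ToolCall`.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DeltaFunction:
    # Stand-in for ChoiceDeltaToolCall.function: the name usually arrives only on
    # the first chunk; arguments arrive as JSON fragments spread across chunks.
    name: Optional[str] = None
    arguments: Optional[str] = None


@dataclass
class DeltaToolCall:
    # Stand-in for ChoiceDeltaToolCall: `index` says which tool call a chunk
    # belongs to; `id` and `type` are present only on that call's first chunk.
    index: int
    id: Optional[str] = None
    type: Optional[str] = None
    function: Optional[DeltaFunction] = None


@dataclass
class ToolCall:
    # Stand-in for AssistantPromptMessage.ToolCall, flattened for brevity.
    id: str
    type: str
    name: str
    arguments: str = ''


def update_tool_calls(tool_calls: list[ToolCall], deltas: Optional[list[DeltaToolCall]]) -> None:
    """Merge one chunk's tool-call deltas into the accumulated list, keyed by index."""
    for delta in deltas or []:
        if delta.index < len(tool_calls):
            # Continuation chunk: concatenate the next arguments fragment.
            entry = tool_calls[delta.index]
            if delta.function:
                entry.name = delta.function.name or entry.name
                entry.arguments += delta.function.arguments or ''
        else:
            # First chunk for this index: it must carry the call's identity.
            assert delta.id is not None and delta.type is not None
            assert delta.function is not None and delta.function.name is not None
            tool_calls.append(ToolCall(
                id=delta.id,
                type=delta.type,
                name=delta.function.name,
                arguments=delta.function.arguments or '',
            ))


# Usage: the arguments JSON is complete only after the final chunk.
calls: list[ToolCall] = []
update_tool_calls(calls, [DeltaToolCall(index=0, id='call_1', type='function',
                                        function=DeltaFunction(name='get_weather', arguments='{"city": '))])
update_tool_calls(calls, [DeltaToolCall(index=0, function=DeltaFunction(arguments='"Paris"}'))])
assert calls == [ToolCall(id='call_1', type='function', name='get_weather', arguments='{"city": "Paris"}')]

Because the arguments only become valid JSON at end of stream, the handler keeps one mutable entry per index and concatenates fragments onto it; the asserts mirror the patch's assumption that the first chunk for a new index always carries the id, type, and function name.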