From 55679b43898d2eb3f71a2fd5fea08c3c0a3e5926 Mon Sep 17 00:00:00 2001
From: "Charlie.Wei"
Date: Wed, 9 Oct 2024 16:15:03 +0800
Subject: [PATCH] =?UTF-8?q?azure=20add=20o1-mini=E3=80=81o1-preview=20mode?=
 =?UTF-8?q?ls=20(#9088)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
---
 .../model_providers/azure_openai/_constant.py |  75 ++++++++++-
 .../azure_openai/azure_openai.yaml            |  12 ++
 .../model_providers/azure_openai/llm/llm.py   | 122 ++++++++++++++++--
 3 files changed, 196 insertions(+), 13 deletions(-)

diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py
index 0dada70cc5..baa5421396 100644
--- a/api/core/model_runtime/model_providers/azure_openai/_constant.py
+++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py
@@ -1081,8 +1081,81 @@ LLM_BASE_MODELS = [
             ),
         ),
     ),
+    AzureBaseModel(
+        base_model_name="o1-preview",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object"],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=32768),
+            ],
+            pricing=PriceConfig(
+                input=15.00,
+                output=60.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
+    AzureBaseModel(
+        base_model_name="o1-mini",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object"],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=65536),
+            ],
+            pricing=PriceConfig(
+                input=3.00,
+                output=12.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
 ]
-
 EMBEDDING_BASE_MODELS = [
     AzureBaseModel(
         base_model_name="text-embedding-ada-002",
diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
index 867f9fec42..cb87bfdf6b 100644
--- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
+++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
@@ -120,6 +120,18 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: llm
+        - label:
+            en_US: o1-mini
+          value: o1-mini
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: o1-preview
+          value: o1-preview
+          show_on:
+            - variable: __model_type
+              value: llm
         - label:
             en_US: gpt-4o-mini
           value: gpt-4o-mini
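
A note on the PriceConfig entries registered above: with unit=0.000001, the input/output figures read as USD per million tokens, and a request's cost works out to tokens * unit * price. A minimal sketch of that arithmetic, not part of the patch (estimate_cost is a hypothetical helper, not a Dify API):

def estimate_cost(prompt_tokens: int, completion_tokens: int) -> float:
    # o1-preview figures from the _constant.py hunk above
    unit = 0.000001       # scales a per-million-token price down to per-token
    input_price = 15.00   # USD per 1M prompt tokens
    output_price = 60.00  # USD per 1M completion tokens
    return prompt_tokens * unit * input_price + completion_tokens * unit * output_price

print(estimate_cost(10_000, 2_000))  # 0.15 + 0.12 = 0.27 USD
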
diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
index f0033ea051..f63c80e58e 100644
--- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -312,20 +312,118 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs["user"] = user
 
-        # chat model
-        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
-        response = client.chat.completions.create(
-            messages=messages,
-            model=model,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs,
+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
+        block_as_stream = False
+        if model.startswith("o1"):
+            if stream:
+                block_as_stream = True
+                stream = False
+
+            if "stream_options" in extra_model_kwargs:
+                del extra_model_kwargs["stream_options"]
+
+            if "stop" in extra_model_kwargs:
+                del extra_model_kwargs["stop"]
+
+        # chat model
+        response = client.chat.completions.create(
+            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            model=model,
+            stream=stream,
+            **model_parameters,
+            **extra_model_kwargs,
+        )
+
+        if stream:
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Convert a blocking chat completion result into a single-chunk stream.
+
+        Used for o1 models, which do not support streaming: the full result
+        is generated first, then replayed to the caller as one stream chunk.
+
+        :param block_result: blocking chat completion result
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=text),
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
         )
 
-        if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
+
+        if model in checklist:
+            # count how many user messages there are
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = "\n".join(
+                                [
+                                    item.data
+                                    if item.type == PromptMessageContentType.TEXT
+                                    else "[IMAGE]"
+                                    if item.type == PromptMessageContentType.IMAGE
+                                    else ""
+                                    for item in prompt_message.content
+                                ]
+                            )
+
+        if model.startswith("o1"):
+            system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
+            if system_message_count > 0:
+                new_prompt_messages = []
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, SystemPromptMessage):
+                        prompt_message = UserPromptMessage(
+                            content=prompt_message.content,
+                            name=prompt_message.name,
+                        )
+
+                    new_prompt_messages.append(prompt_message)
+                prompt_messages = new_prompt_messages
+
+        return prompt_messages
 
     def _handle_chat_generate_response(
         self,
@@ -560,7 +658,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
            tokens_per_name = 1
         else: