From c98d91e44d75cf03f395eee521e5af9a36a45ad8 Mon Sep 17 00:00:00 2001 From: jiangbo721 <365065261@qq.com> Date: Wed, 25 Dec 2024 13:29:43 +0800 Subject: [PATCH] fix: o1 model error, use max_completion_tokens instead of max_tokens. (#12037) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 刘江波 --- .../model_providers/azure_openai/llm/llm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index c5d7a83a4e..03818741f6 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -113,7 +113,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): try: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - if "o1" in model: + if model.startswith("o1"): client.chat.completions.create( messages=[{"role": "user", "content": "ping"}], model=model, @@ -311,7 +311,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) block_as_stream = False - if "o1" in model: + if model.startswith("o1"): + if "max_tokens" in model_parameters: + model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] + del model_parameters["max_tokens"] if stream: block_as_stream = True stream = False @@ -404,7 +407,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ] ) - if "o1" in model: + if model.startswith("o1"): system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) if system_message_count > 0: new_prompt_messages = []