fix: o1 model error, use max_completion_tokens instead of max_tokens. (#12037)

Co-authored-by: 刘江波 <jiangbo721@163.com>
This commit is contained in:
jiangbo721 2024-12-25 13:29:43 +08:00 committed by GitHub
parent 3ea54e9d25
commit c98d91e44d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -113,7 +113,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
try: try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials)) client = AzureOpenAI(**self._to_credential_kwargs(credentials))
if "o1" in model: if model.startswith("o1"):
client.chat.completions.create( client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}], messages=[{"role": "user", "content": "ping"}],
model=model, model=model,
@ -311,7 +311,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
block_as_stream = False block_as_stream = False
if "o1" in model: if model.startswith("o1"):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
if stream: if stream:
block_as_stream = True block_as_stream = True
stream = False stream = False
@ -404,7 +407,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
] ]
) )
if "o1" in model: if model.startswith("o1"):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0: if system_message_count > 0:
new_prompt_messages = [] new_prompt_messages = []