feat: add OpenAI o1 series models support (#8328)

This commit is contained in:
takatost 2024-09-13 02:15:19 +08:00 committed by GitHub
parent 153807f243
commit e90d3c29ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 180 additions and 4 deletions

View File

@ -5,6 +5,10 @@
- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- o1-preview
- o1-preview-2024-09-12
- o1-mini
- o1-mini-2024-09-12
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview

View File

@ -613,6 +613,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
block_as_stream = False
if model.startswith("o1"):
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
# chat model
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@ -625,7 +632,39 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
return block_result
def _handle_chat_block_as_stream_response(
    self,
    block_result: LLMResult,
    prompt_messages: list[PromptMessage],
) -> Generator[LLMResultChunk, None, None]:
    """
    Convert a blocking (non-streaming) chat completion result into a
    single-chunk stream.

    Used for models that do not support streaming (e.g. the o1 series,
    where the caller forces `stream = False` and sets `block_as_stream`),
    so callers that expect a chunk generator still receive one.

    :param block_result: the complete, non-streaming LLM result
    :param prompt_messages: the prompt messages the request was made with
    :return: generator yielding exactly one chunk that carries the full
             message, the usage stats, and a "stop" finish reason
    """
    # Wrap the whole blocking result in a single terminal chunk.
    yield LLMResultChunk(
        model=block_result.model,
        prompt_messages=prompt_messages,
        system_fingerprint=block_result.system_fingerprint,
        delta=LLMResultChunkDelta(
            index=0,
            message=block_result.message,
            finish_reason="stop",
            usage=block_result.usage,
        ),
    )
def _handle_chat_generate_response(
self,
@ -960,7 +999,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model = model.split(":")[1]
# Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
if model == "chatgpt-4o-latest":
if model == "chatgpt-4o-latest" or model.startswith("o1"):
model = "gpt-4o"
try:
@ -975,7 +1014,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
tokens_per_message = 3
tokens_per_name = 1
else:

View File

@ -0,0 +1,33 @@
model: o1-mini-2024-09-12
label:
zh_Hans: o1-mini-2024-09-12
en_US: o1-mini-2024-09-12
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 65536
min: 1
max: 65536
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '3.00'
output: '12.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,33 @@
model: o1-mini
label:
zh_Hans: o1-mini
en_US: o1-mini
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 65536
min: 1
max: 65536
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '3.00'
output: '12.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,33 @@
model: o1-preview-2024-09-12
label:
zh_Hans: o1-preview-2024-09-12
en_US: o1-preview-2024-09-12
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 32768
min: 1
max: 32768
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '15.00'
output: '60.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,33 @@
model: o1-preview
label:
zh_Hans: o1-preview
en_US: o1-preview
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 32768
min: 1
max: 32768
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '15.00'
output: '60.00'
unit: '0.000001'
currency: USD

View File

@ -60,7 +60,8 @@ ignore = [
"SIM113", # eumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"SIM300", # yoda-conditions
"SIM300", # yoda-conditions
"PT004", # pytest-no-assert
]
[tool.ruff.lint.per-file-ignores]