feat: add OpenAI o1 series models support (#8328)
parent 153807f243
commit e90d3c29ab
@@ -5,6 +5,10 @@
 - chatgpt-4o-latest
 - gpt-4o-mini
 - gpt-4o-mini-2024-07-18
+- o1-preview
+- o1-preview-2024-09-12
+- o1-mini
+- o1-mini-2024-09-12
 - gpt-4-turbo
 - gpt-4-turbo-2024-04-09
 - gpt-4-turbo-preview
@@ -613,6 +613,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
+        block_as_stream = False
+        if model.startswith("o1"):
+            block_as_stream = True
+            stream = False
+            if "stream_options" in extra_model_kwargs:
+                del extra_model_kwargs["stream_options"]
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
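At the time of this commit the o1-preview/o1-mini endpoints did not accept streaming requests, which is presumably why the hunk above forces a blocking call and drops the streaming-only stream_options argument. A minimal stand-alone sketch of that guard (the wrapper function itself is hypothetical; the variable names mirror the hunk):

# Sketch of the guard added above. o1-* requests are forced to be non-streaming, and the
# streaming-only `stream_options` argument is removed so the API call stays valid.
def prepare_request(model: str, stream: bool, extra_model_kwargs: dict) -> tuple[bool, bool, dict]:
    block_as_stream = False
    if model.startswith("o1"):
        block_as_stream = True   # remember to replay the blocking result as a stream later
        stream = False           # o1 models rejected stream=True when this change landed
        extra_model_kwargs.pop("stream_options", None)
    return block_as_stream, stream, extra_model_kwargs

# e.g. prepare_request("o1-preview", True, {"stream_options": {"include_usage": True}})
# -> (True, False, {})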
@@ -625,7 +632,39 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: llm response chunk generator
+        """
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
 
     def _handle_chat_generate_response(
         self,
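The new _handle_chat_block_as_stream_response keeps stream-oriented callers working by replaying the finished blocking result as a generator that yields exactly one chunk with finish_reason="stop". A self-contained toy version of the same pattern, using simplified stand-ins rather than Dify's real LLMResult/LLMResultChunk classes:

# Toy "block as stream" wrapper: a finished (non-streaming) result is re-emitted as a
# one-chunk stream. The dataclasses are illustrative stand-ins, not Dify entities.
from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class FakeResult:
    model: str
    text: str

@dataclass
class FakeChunk:
    model: str
    text: str
    finish_reason: str

def block_as_stream(result: FakeResult) -> Generator[FakeChunk, None, None]:
    # The entire blocking answer arrives as a single terminal chunk.
    yield FakeChunk(model=result.model, text=result.text, finish_reason="stop")

chunks = list(block_as_stream(FakeResult(model="o1-preview", text="hello")))
assert len(chunks) == 1 and chunks[0].finish_reason == "stop"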
@@ -960,7 +999,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             model = model.split(":")[1]
 
         # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
-        if model == "chatgpt-4o-latest":
+        if model == "chatgpt-4o-latest" or model.startswith("o1"):
             model = "gpt-4o"
 
         try:
@@ -975,7 +1014,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
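Both token-counting hunks reuse the gpt-4o path for o1-* names. A plausible reason, not stated in the commit, is that tiktoken had no entry for the o1 model names when this landed, while gpt-4o resolves to the o200k_base encoding that the o1 series also uses. A hedged sketch of that fallback (the helper name is made up):

import tiktoken

def encoding_for(model: str) -> tiktoken.Encoding:
    # Hypothetical helper: route names tiktoken may not know to gpt-4o, as the diff does.
    if model == "chatgpt-4o-latest" or model.startswith("o1"):
        model = "gpt-4o"  # gpt-4o and the o1 series share the o200k_base vocabulary
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")  # generic fallback for unknown names

print(len(encoding_for("o1-mini").encode("hello world")))  # token count via the gpt-4o encoding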
@@ -0,0 +1,33 @@
+model: o1-mini-2024-09-12
+label:
+  zh_Hans: o1-mini-2024-09-12
+  en_US: o1-mini-2024-09-12
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 65563
+    min: 1
+    max: 65563
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '3.00'
+  output: '12.00'
+  unit: '0.000001'
+  currency: USD
@@ -0,0 +1,33 @@
+model: o1-mini
+label:
+  zh_Hans: o1-mini
+  en_US: o1-mini
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 65563
+    min: 1
+    max: 65563
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '3.00'
+  output: '12.00'
+  unit: '0.000001'
+  currency: USD
@@ -0,0 +1,33 @@
+model: o1-preview-2024-09-12
+label:
+  zh_Hans: o1-preview-2024-09-12
+  en_US: o1-preview-2024-09-12
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 32768
+    min: 1
+    max: 32768
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '15.00'
+  output: '60.00'
+  unit: '0.000001'
+  currency: USD
@@ -0,0 +1,33 @@
+model: o1-preview
+label:
+  zh_Hans: o1-preview
+  en_US: o1-preview
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 32768
+    min: 1
+    max: 32768
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '15.00'
+  output: '60.00'
+  unit: '0.000001'
+  currency: USD
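Reading the pricing blocks above (an interpretation of the fields, not something stated in the diff): unit appears to be the per-token multiplier, so the quoted price times unit gives USD per token, e.g. o1-preview input is 15.00 * 0.000001 = $0.000015 per token, i.e. $15 per million tokens. A small sketch of that arithmetic:

from decimal import Decimal

def token_cost(tokens: int, price: str, unit: str = "0.000001") -> Decimal:
    # Cost in USD for `tokens` tokens at the YAML `price`, assuming `unit` is the per-token multiplier.
    return Decimal(tokens) * Decimal(price) * Decimal(unit)

print(token_cost(1_000_000, "15.00"))  # 15.00000000 -> $15 per million o1-preview input tokens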
@@ -60,7 +60,8 @@ ignore = [
     "SIM113", # eumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
-    "SIM300", # yoda-conditions
+    "SIM300", # yoda-conditions,
+    "PT004", # pytest-no-assert
 ]
 
 [tool.ruff.lint.per-file-ignores]