feat: add OpenAI o1 series models support (#8328)
commit e90d3c29ab
parent 153807f243
@@ -5,6 +5,10 @@
 - chatgpt-4o-latest
 - gpt-4o-mini
 - gpt-4o-mini-2024-07-18
+- o1-preview
+- o1-preview-2024-09-12
+- o1-mini
+- o1-mini-2024-09-12
 - gpt-4-turbo
 - gpt-4-turbo-2024-04-09
 - gpt-4-turbo-preview
@@ -613,6 +613,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
+        block_as_stream = False
+        if model.startswith("o1"):
+            block_as_stream = True
+            stream = False
+            if "stream_options" in extra_model_kwargs:
+                del extra_model_kwargs["stream_options"]
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
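The o1 endpoints reject `stream=True` (and the accompanying `stream_options` field), so the request is forced into blocking mode while a flag records that the caller still expects a stream. A minimal standalone sketch of that gating, using hypothetical names rather than Dify's actual helper:

```python
def prepare_o1_request(model: str, stream: bool, extra_model_kwargs: dict) -> tuple[bool, bool]:
    """Return (stream, block_as_stream) adjusted for o1-style models."""
    block_as_stream = False
    if model.startswith("o1"):
        block_as_stream = True  # blocking result will be re-emitted as a stream
        stream = False          # o1 models do not support streamed completions
        extra_model_kwargs.pop("stream_options", None)  # only meaningful when streaming
    return stream, block_as_stream
```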
@@ -625,7 +632,36 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle a blocking chat response by yielding it as a single stream chunk.
+
+        :param block_result: complete result of the blocking chat call
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
+
     def _handle_chat_generate_response(
         self,
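The new `_handle_chat_block_as_stream_response` turns the blocking result into a generator that yields exactly one terminal chunk, so downstream code written against the streaming interface needs no special case for o1. A self-contained sketch of the pattern, using illustrative types rather than Dify's `LLMResultChunk`:

```python
from collections.abc import Generator
from dataclasses import dataclass
from typing import Optional

@dataclass
class Chunk:
    text: str
    finish_reason: Optional[str] = None

def block_as_stream(full_text: str) -> Generator[Chunk, None, None]:
    # Emit the complete blocking result as a single terminal chunk,
    # mirroring the diff: index 0, full message, finish_reason "stop".
    yield Chunk(text=full_text, finish_reason="stop")

# Callers written for streaming keep working unchanged:
for chunk in block_as_stream("final answer from an o1 model"):
    print(chunk.text, chunk.finish_reason)
```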
@@ -960,7 +996,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             model = model.split(":")[1]
 
         # Currently, we use gpt-4o's tokenizer to count tokens for chatgpt-4o-latest and the o1 series.
-        if model == "chatgpt-4o-latest":
+        if model == "chatgpt-4o-latest" or model.startswith("o1"):
             model = "gpt-4o"
 
         try:
@@ -975,7 +1011,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
             tokens_per_name = 1
         else:
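Both token-counting hunks route o1 models through gpt-4o's tokenizer, since tiktoken ships no dedicated o1 encoding. A runnable sketch of the counting scheme (tiktoken assumed installed; the per-message constants and the +3 reply priming follow OpenAI's cookbook convention):

```python
import tiktoken

def estimate_chat_tokens(model: str, messages: list[dict]) -> int:
    if model == "chatgpt-4o-latest" or model.startswith("o1"):
        model = "gpt-4o"  # fall back to gpt-4o's tokenizer, as in the diff
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message, tokens_per_name = 3, 1  # gpt-3.5-turbo / gpt-4 / o1 rates
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens

print(estimate_chat_tokens("o1-mini", [{"role": "user", "content": "hello"}]))
```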
@@ -0,0 +1,33 @@
+model: o1-mini-2024-09-12
+label:
+  zh_Hans: o1-mini-2024-09-12
+  en_US: o1-mini-2024-09-12
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 65536
+    min: 1
+    max: 65536
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '3.00'
+  output: '12.00'
+  unit: '0.000001'
+  currency: USD
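This file and the three model definitions below share the same schema; a hypothetical loader shows how such a file can be read and sanity-checked (PyYAML assumed; this is illustrative, not Dify's actual config plumbing):

```python
import yaml

def load_model_config(path: str) -> dict:
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # A parameter default must sit inside the advertised [min, max] range.
    for rule in config.get("parameter_rules", []):
        lo, hi, default = rule.get("min"), rule.get("max"), rule.get("default")
        if None not in (lo, hi, default):
            assert lo <= default <= hi, f"bad default for {rule['name']}"
    return config

config = load_model_config("o1-mini-2024-09-12.yaml")
print(config["model"], config["model_properties"]["context_size"])
```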
@@ -0,0 +1,33 @@
+model: o1-mini
+label:
+  zh_Hans: o1-mini
+  en_US: o1-mini
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 65536
+    min: 1
+    max: 65536
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '3.00'
+  output: '12.00'
+  unit: '0.000001'
+  currency: USD
@@ -0,0 +1,33 @@
+model: o1-preview-2024-09-12
+label:
+  zh_Hans: o1-preview-2024-09-12
+  en_US: o1-preview-2024-09-12
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 32768
+    min: 1
+    max: 32768
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '15.00'
+  output: '60.00'
+  unit: '0.000001'
+  currency: USD
@@ -0,0 +1,33 @@
+model: o1-preview
+label:
+  zh_Hans: o1-preview
+  en_US: o1-preview
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 32768
+    min: 1
+    max: 32768
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '15.00'
+  output: '60.00'
+  unit: '0.000001'
+  currency: USD
@@ -60,7 +60,8 @@ ignore = [
     "SIM113", # enumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
     "SIM300", # yoda-conditions
+    "PT004", # pytest-no-assert
 ]
 
 [tool.ruff.lint.per-file-ignores]