mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-20 08:09:24 +08:00
azure add o1-mini、o1-preview models (#9088)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
parent c0b71f8286
commit 55679b4389
@@ -1081,8 +1081,81 @@ LLM_BASE_MODELS = [
             ),
         ),
     ),
+    AzureBaseModel(
+        base_model_name="o1-preview",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object"],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=32768),
+            ],
+            pricing=PriceConfig(
+                input=15.00,
+                output=60.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
+    AzureBaseModel(
+        base_model_name="o1-mini",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object"],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=65536),
+            ],
+            pricing=PriceConfig(
+                input=3.00,
+                output=12.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
 ]

 EMBEDDING_BASE_MODELS = [
     AzureBaseModel(
         base_model_name="text-embedding-ada-002",
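Both entries delegate the max_tokens rule to the provider's `_get_max_tokens` helper, capping completion length at 32,768 tokens for o1-preview and 65,536 for o1-mini. The helper itself is not part of this hunk; the following is a minimal sketch of the shape it plausibly returns, inferred from the `ParameterRule` fields used above, not the file's actual implementation:

def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
    # Sketch only: build the shared max_tokens rule with per-model bounds.
    return ParameterRule(
        name="max_tokens",
        label=I18nObject(en_US="max_tokens"),
        type="int",
        required=False,
        default=default,
        min=min_val,
        max=max_val,
    )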
@@ -120,6 +120,18 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: llm
+        - label:
+            en_US: o1-mini
+          value: o1-mini
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: o1-preview
+          value: o1-preview
+          show_on:
+            - variable: __model_type
+              value: llm
         - label:
             en_US: gpt-4o-mini
           value: gpt-4o-mini
@@ -312,10 +312,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs["user"] = user

+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
+        block_as_stream = False
+        if model.startswith("o1"):
+            if stream:
+                block_as_stream = True
+                stream = False
+
+            if "stream_options" in extra_model_kwargs:
+                del extra_model_kwargs["stream_options"]
+
+            if "stop" in extra_model_kwargs:
+                del extra_model_kwargs["stop"]
+
         # chat model
-        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=messages,
+            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
             model=model,
             stream=stream,
             **model_parameters,
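The new kwargs scrub works around API restrictions on the o1 series at launch: the endpoints reject `stream=True`, `stream_options`, and `stop`. A caller's streaming request is remembered in `block_as_stream` and downgraded to a blocking call, and `stop` is not lost, since it is re-applied client-side in the block-as-stream path shown in the next hunk. The same logic as a standalone sketch; the function name and signature here are illustrative, not part of the commit:

def scrub_o1_request(model: str, stream: bool, extra_model_kwargs: dict) -> tuple[bool, bool]:
    # Return (stream, block_as_stream) with options the o1 API rejects removed.
    block_as_stream = False
    if model.startswith("o1"):
        if stream:
            block_as_stream = True  # remember that the caller wanted a stream
            stream = False  # ...but issue a blocking request instead
        extra_model_kwargs.pop("stream_options", None)  # rejected by o1
        extra_model_kwargs.pop("stop", None)  # enforced client-side later
    return stream, block_as_stream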
@@ -325,7 +339,91 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)

-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response by replaying a blocking result as a one-chunk stream
+
+        :param block_result: blocking chat result
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=text),
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
+
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
+
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = "\n".join(
+                                [
+                                    item.data
+                                    if item.type == PromptMessageContentType.TEXT
+                                    else "[IMAGE]"
+                                    if item.type == PromptMessageContentType.IMAGE
+                                    else ""
+                                    for item in prompt_message.content
+                                ]
+                            )
+
+        if model.startswith("o1"):
+            system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
+            if system_message_count > 0:
+                new_prompt_messages = []
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, SystemPromptMessage):
+                        prompt_message = UserPromptMessage(
+                            content=prompt_message.content,
+                            name=prompt_message.name,
+                        )
+
+                    new_prompt_messages.append(prompt_message)
+                prompt_messages = new_prompt_messages
+
+        return prompt_messages
+
     def _handle_chat_generate_response(
         self,
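`_handle_chat_block_as_stream_response` replays the blocking result as a stream that yields exactly one chunk carrying the full completion and final usage, and `_clear_illegal_prompt_messages` folds system messages into user messages because o1 deployments rejected the system role. Stop words are re-applied via the base class's `enforce_stop_tokens`; that helper is not shown in the diff, so the following truncate-at-earliest-match behaviour is an assumption about what it provides:

def enforce_stop_tokens(text: str, stop: list[str]) -> str:
    # Assumed behaviour: cut the text at the earliest occurrence of any stop word.
    cut = len(text)
    for token in stop:
        idx = text.find(token)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]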
@@ -560,7 +658,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
             tokens_per_name = 1
         else:
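These constants feed the usual OpenAI-cookbook message token count, which this change extends to o1 deployments, presumably because o1 shares gpt-4's chat framing. A self-contained sketch of that counting loop; the exact shape of the surrounding method in llm.py may differ:

import tiktoken

def num_tokens_from_messages(messages: list[dict], tokens_per_message: int = 3, tokens_per_name: int = 1) -> int:
    # Cookbook-style estimate: fixed overhead per message plus the encoded
    # length of each field, with an adjustment when a "name" field is present.
    encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens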