diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 093f57c51e..1ef5e83abc 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -53,6 +53,9 @@ model_credential_schema: type: select required: true options: + - label: + en_US: 2024-10-01-preview + value: 2024-10-01-preview - label: en_US: 2024-09-01-preview value: 2024-09-01-preview diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 4f46fb81f8..1cd4823e13 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -45,9 +45,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): stream: bool = True, user: Optional[str] = None, ) -> Union[LLMResult, Generator]: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: @@ -81,9 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None, ) -> int: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not model_entity: raise ValueError(f"Base Model Name {base_model_name} is invalid") @@ -108,9 +104,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if "base_model_name" not in credentials: raise CredentialsValidateFailedError("Base Model Name is required") - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise CredentialsValidateFailedError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not ai_model_entity: @@ -149,9 +143,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) return ai_model_entity.entity if ai_model_entity else None @@ -308,11 +300,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if tools: extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - # extra_model_kwargs['functions'] = [{ - # "name": tool.name, - # "description": tool.description, - # "parameters": tool.parameters - # } for tool in tools] if stop: extra_model_kwargs["stop"] = stop @@ -769,3 +756,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.zh_Hans = model return ai_model_entity_copy + + def _get_base_model_name(self, credentials: dict) -> str: + base_model_name = credentials.get("base_model_name") + if not base_model_name: + raise ValueError("Base Model Name is required") + return base_model_name