mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 06:45:53 +08:00
chore: Extract common functions of the base model in Azure OpenAI Provider (#9907)
This commit is contained in:
parent
216442ddc1
commit
22776f24ab
@ -53,6 +53,9 @@ model_credential_schema:
|
|||||||
type: select
|
type: select
|
||||||
required: true
|
required: true
|
||||||
options:
|
options:
|
||||||
|
- label:
|
||||||
|
en_US: 2024-10-01-preview
|
||||||
|
value: 2024-10-01-preview
|
||||||
- label:
|
- label:
|
||||||
en_US: 2024-09-01-preview
|
en_US: 2024-09-01-preview
|
||||||
value: 2024-09-01-preview
|
value: 2024-09-01-preview
|
||||||
|
@ -45,9 +45,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
base_model_name = credentials.get("base_model_name")
|
base_model_name = self._get_base_model_name(credentials)
|
||||||
if not base_model_name:
|
|
||||||
raise ValueError("Base Model Name is required")
|
|
||||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||||
|
|
||||||
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||||
@ -81,9 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
base_model_name = credentials.get("base_model_name")
|
base_model_name = self._get_base_model_name(credentials)
|
||||||
if not base_model_name:
|
|
||||||
raise ValueError("Base Model Name is required")
|
|
||||||
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||||
if not model_entity:
|
if not model_entity:
|
||||||
raise ValueError(f"Base Model Name {base_model_name} is invalid")
|
raise ValueError(f"Base Model Name {base_model_name} is invalid")
|
||||||
@ -108,9 +104,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
if "base_model_name" not in credentials:
|
if "base_model_name" not in credentials:
|
||||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||||
|
|
||||||
base_model_name = credentials.get("base_model_name")
|
base_model_name = self._get_base_model_name(credentials)
|
||||||
if not base_model_name:
|
|
||||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
|
||||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||||
|
|
||||||
if not ai_model_entity:
|
if not ai_model_entity:
|
||||||
@ -149,9 +143,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||||
base_model_name = credentials.get("base_model_name")
|
base_model_name = self._get_base_model_name(credentials)
|
||||||
if not base_model_name:
|
|
||||||
raise ValueError("Base Model Name is required")
|
|
||||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||||
return ai_model_entity.entity if ai_model_entity else None
|
return ai_model_entity.entity if ai_model_entity else None
|
||||||
|
|
||||||
@ -308,11 +300,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||||
# extra_model_kwargs['functions'] = [{
|
|
||||||
# "name": tool.name,
|
|
||||||
# "description": tool.description,
|
|
||||||
# "parameters": tool.parameters
|
|
||||||
# } for tool in tools]
|
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
extra_model_kwargs["stop"] = stop
|
extra_model_kwargs["stop"] = stop
|
||||||
@ -769,3 +756,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
ai_model_entity_copy.entity.label.en_US = model
|
ai_model_entity_copy.entity.label.en_US = model
|
||||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||||
return ai_model_entity_copy
|
return ai_model_entity_copy
|
||||||
|
|
||||||
|
def _get_base_model_name(self, credentials: dict) -> str:
|
||||||
|
base_model_name = credentials.get("base_model_name")
|
||||||
|
if not base_model_name:
|
||||||
|
raise ValueError("Base Model Name is required")
|
||||||
|
return base_model_name
|
||||||
|
Loading…
x
Reference in New Issue
Block a user