mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:48:58 +08:00
Feat/open azure validate (#163)
This commit is contained in:
parent
1c5f63de7e
commit
6da5e54180
@ -157,7 +157,7 @@ class ProviderTokenValidateApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: remove this when the provider is supported
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
|
||||
|
||||
|
@ -78,7 +78,7 @@ class AzureProvider(BaseProvider):
|
||||
|
||||
def get_token_type(self):
|
||||
# TODO: change to dict when implemented
|
||||
return lambda value: value
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
@ -91,16 +91,34 @@ class AzureProvider(BaseProvider):
|
||||
if 'openai_api_version' not in config:
|
||||
config['openai_api_version'] = '2023-03-15-preview'
|
||||
|
||||
self.get_models(credentials=config)
|
||||
models = self.get_models(credentials=config)
|
||||
|
||||
if not models:
|
||||
raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
|
||||
"'gpt-3.5-turbo', 'text-embedding-ada-002'.")
|
||||
|
||||
fixed_model_ids = [
|
||||
'text-davinci-003',
|
||||
'gpt-35-turbo',
|
||||
'text-embedding-ada-002'
|
||||
]
|
||||
|
||||
current_model_ids = [model['id'] for model in models]
|
||||
|
||||
missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
|
||||
fixed_model_id not in current_model_ids]
|
||||
|
||||
if missing_model_ids:
|
||||
raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
|
||||
except AzureAuthenticationError:
|
||||
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.')
|
||||
except requests.ConnectionError:
|
||||
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.')
|
||||
raise ValidateFailedError('Validation failed, please check your API Key.')
|
||||
except (requests.ConnectionError, requests.RequestException):
|
||||
raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
|
||||
except AzureRequestFailedError as ex:
|
||||
raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex)))
|
||||
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
||||
except Exception as ex:
|
||||
logging.exception('Azure OpenAI Credentials validation failed')
|
||||
raise ex
|
||||
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user