diff --git a/api/controllers/console/workspace/providers.py b/api/controllers/console/workspace/providers.py index 87dad0d93a..f2baec29c1 100644 --- a/api/controllers/console/workspace/providers.py +++ b/api/controllers/console/workspace/providers.py @@ -157,7 +157,7 @@ class ProviderTokenValidateApi(Resource): args = parser.parse_args() # todo: remove this when the provider is supported - if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, + if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value, ProviderName.HUGGINGFACEHUB.value]: return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'} diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py index bd44a0cc4b..649c64cf73 100644 --- a/api/core/llm/provider/azure_provider.py +++ b/api/core/llm/provider/azure_provider.py @@ -78,7 +78,7 @@ class AzureProvider(BaseProvider): def get_token_type(self): # TODO: change to dict when implemented - return lambda value: value + return dict def config_validate(self, config: Union[dict | str]): """ @@ -91,16 +91,34 @@ class AzureProvider(BaseProvider): if 'openai_api_version' not in config: config['openai_api_version'] = '2023-03-15-preview' - self.get_models(credentials=config) + models = self.get_models(credentials=config) + + if not models: + raise ValidateFailedError("Please add deployments for 'text-davinci-003', " + "'gpt-3.5-turbo', 'text-embedding-ada-002'.") + + fixed_model_ids = [ + 'text-davinci-003', + 'gpt-35-turbo', + 'text-embedding-ada-002' + ] + + current_model_ids = [model['id'] for model in models] + + missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if + fixed_model_id not in current_model_ids] + + if missing_model_ids: + raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids))) except AzureAuthenticationError: - raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.') - except requests.ConnectionError: - raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.') + raise ValidateFailedError('Validation failed, please check your API Key.') + except (requests.ConnectionError, requests.RequestException): + raise ValidateFailedError('Validation failed, please check your API Base Endpoint.') except AzureRequestFailedError as ex: - raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex))) + raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) except Exception as ex: logging.exception('Azure OpenAI Credentials validation failed') - raise ex + raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) def get_encrypted_token(self, config: Union[dict | str]): """