fix: azure openai deployment list was deprecated suddenly (#611)

This commit is contained in:
John Wang 2023-07-20 13:46:39 +08:00 committed by GitHub
parent 52c84da051
commit cae15013e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@ import json
import logging import logging
from typing import Optional, Union from typing import Optional, Union
import openai
import requests import requests
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
@ -11,30 +12,37 @@ from models.provider import ProviderName
class AzureProvider(BaseProvider): class AzureProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
credentials = self.get_credentials(model_id) if not credentials else credentials return []
url = "{}/openai/deployments?api-version={}".format(
str(credentials.get('openai_api_base')),
str(credentials.get('openai_api_version'))
)
headers = { def check_embedding_model(self, credentials: Optional[dict] = None):
"api-key": str(credentials.get('openai_api_key')), credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
"content-type": "application/json; charset=utf-8" try:
} result = openai.Embedding.create(input=['test'],
engine='text-embedding-ada-0021',
response = requests.get(url, headers=headers) timeout=60,
api_key=str(credentials.get('openai_api_key')),
if response.status_code == 200: api_base=str(credentials.get('openai_api_base')),
result = response.json() api_type='azure',
return [{ api_version=str(credentials.get('openai_api_version')))["data"][0][
'id': deployment['id'], "embedding"]
'name': '{} ({})'.format(deployment['id'], deployment['model']) except openai.error.AuthenticationError as e:
} for deployment in result['data'] if deployment['status'] == 'succeeded'] raise AzureAuthenticationError(str(e))
else: except openai.error.APIConnectionError as e:
if response.status_code == 401: raise AzureRequestFailedError(
raise AzureAuthenticationError() 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
except openai.error.InvalidRequestError as e:
if e.http_status == 404:
raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
"deployment name is exists in Azure AI")
else: else:
raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code)) raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
except openai.error.OpenAIError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
if not isinstance(result, list):
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
def get_credentials(self, model_id: Optional[str] = None) -> dict: def get_credentials(self, model_id: Optional[str] = None) -> dict:
""" """
@ -94,31 +102,11 @@ class AzureProvider(BaseProvider):
if 'openai_api_version' not in config: if 'openai_api_version' not in config:
config['openai_api_version'] = '2023-03-15-preview' config['openai_api_version'] = '2023-03-15-preview'
models = self.get_models(credentials=config) self.check_embedding_model(credentials=config)
if not models:
raise ValidateFailedError("Please add deployments for "
"'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
"and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).")
fixed_model_ids = [
'gpt-35-turbo',
'text-embedding-ada-002'
]
current_model_ids = [model['id'] for model in models]
missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
fixed_model_id not in current_model_ids]
if missing_model_ids:
raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
except ValidateFailedError as e: except ValidateFailedError as e:
raise e raise e
except AzureAuthenticationError: except AzureAuthenticationError:
raise ValidateFailedError('Validation failed, please check your API Key.') raise ValidateFailedError('Validation failed, please check your API Key.')
except (requests.ConnectionError, requests.RequestException):
raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
except AzureRequestFailedError as ex: except AzureRequestFailedError as ex:
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex: except Exception as ex: