diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py
index 90dd2e7a6b..8d6cac3ec3 100644
--- a/api/core/model_runtime/model_providers/azure_openai/_constant.py
+++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py
@@ -526,3 +526,20 @@ EMBEDDING_BASE_MODELS = [
         )
     )
 ]
+SPEECH2TEXT_BASE_MODELS = [
+    AzureBaseModel(
+        base_model_name='whisper-1',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label'
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={
+                ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
+                ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
+            }
+        )
+    )
+]
diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
index c081808639..fe4f3538ed 100644
--- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
+++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml
@@ -15,6 +15,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - speech2text
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -99,6 +100,12 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: text-embedding
+        - label:
+            en_US: whisper-1
+          value: whisper-1
+          show_on:
+            - variable: __model_type
+              value: speech2text
       placeholder:
         zh_Hans: 在此输入您的模型版本
         en_US: Enter your model version
diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/__init__.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
new file mode 100644
index 0000000000..763a6d90e3
--- /dev/null
+++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
@@ -0,0 +1,81 @@
+import copy
+from typing import IO, Optional
+
+from openai import AzureOpenAI
+
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
+from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
+
+
+class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
+    """
+    Model class for Azure OpenAI speech-to-text model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                file: IO[bytes], user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        return self._speech2text_invoke(model, credentials, file)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            audio_file_path = self._get_demo_file_path()
+
+            with open(audio_file_path, 'rb') as audio_file:
+                self._speech2text_invoke(model, credentials, audio_file)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :return: text for given audio file
+        """
+        # transform credentials to kwargs for model instance
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+
+        # init model client
+        client = AzureOpenAI(**credentials_kwargs)
+
+        response = client.audio.transcriptions.create(model=model, file=file)
+
+        return response.text
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
+        return ai_model_entity.entity if ai_model_entity else None
+
+    @staticmethod
+    def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
+        for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
+            if ai_model_entity.base_model_name == base_model_name:
+                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
+                ai_model_entity_copy.entity.model = model
+                ai_model_entity_copy.entity.label.en_US = model
+                ai_model_entity_copy.entity.label.zh_Hans = model
+                return ai_model_entity_copy
+
+        return None
diff --git a/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml b/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml
index 3cfe6f1a3a..6c14c76619 100644
--- a/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml
+++ b/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml
@@ -2,4 +2,4 @@ model: whisper-1
 model_type: speech2text
 model_properties:
   file_upload_limit: 25
-  supported_file_extensions: mp3,mp4,mpeg,mpga,m4a,wav,webm
+  supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text-embedidng-ada-002.yaml b/api/core/model_runtime/model_providers/openai/text_embedding/text-embedding-ada-002.yaml
similarity index 100%
rename from api/core/model_runtime/model_providers/openai/text_embedding/text-embedidng-ada-002.yaml
rename to api/core/model_runtime/model_providers/openai/text_embedding/text-embedding-ada-002.yaml