diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml
index a870914632..151f02ee6f 100644
--- a/api/core/model_runtime/model_providers/localai/localai.yaml
+++ b/api/core/model_runtime/model_providers/localai/localai.yaml
@@ -15,6 +15,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - speech2text
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -57,6 +58,9 @@ model_credential_schema:
         zh_Hans: 在此输入LocalAI的服务器地址，如 http://192.168.1.100:8080
         en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
     - variable: context_size
+      show_on:
+        - variable: __model_type
+          value: llm
       label:
         zh_Hans: 上下文大小
         en_US: Context size
diff --git a/api/core/model_runtime/model_providers/localai/speech2text/__init__.py b/api/core/model_runtime/model_providers/localai/speech2text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py
new file mode 100644
index 0000000000..d7403aff4f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py
@@ -0,0 +1,101 @@
+from typing import IO, Optional
+
+from requests import Request, Session
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+
+
+class LocalAISpeech2text(Speech2TextModel):
+    """
+    Model class for LocalAI speech-to-text model.
+ """ + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + + url = str(URL(credentials['server_url']) / "v1/audio/transcriptions") + data = {"model": model} + files = {"file": file} + + session = Session() + request = Request("POST", url, data=data, files=files) + prepared_request = session.prepare_request(request) + response = session.send(prepared_request) + + if 'error' in response.json(): + raise InvokeServerUnavailableError("Empty response") + + return response.json()["text"] + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + audio_file_path = self._get_demo_file_path() + + with open(audio_file_path, 'rb') as audio_file: + self._invoke(model, credentials, audio_file) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError + ], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={}, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py new file mode 100644 index 0000000000..3fd2ebed4f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py @@ -0,0 +1,54 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text + + +def test_validate_credentials(): + model = LocalAISpeech2text() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='whisper-1', + credentials={ + 'server_url': 'invalid_url' + } + ) + + model.validate_credentials( + model='whisper-1', + credentials={ + 'server_url': os.environ.get('LOCALAI_SERVER_URL') + } + ) + + +def test_invoke_model(): + model = LocalAISpeech2text() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, 'audio.mp3') + + # Open the file and get the file object + with open(audio_file_path, 'rb') as audio_file: + file = audio_file + + result = model.invoke( + model='whisper-1', + credentials={ + 'server_url': os.environ.get('LOCALAI_SERVER_URL') + }, + file=file, + user="abc-123" + ) + + assert isinstance(result, str) + assert result == 
'1, 2, 3, 4, 5, 6, 7, 8, 9, 10' \ No newline at end of file
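For reference, the adapter's _invoke simply forwards a multipart upload to LocalAI's OpenAI-compatible /v1/audio/transcriptions endpoint and returns the "text" field of the JSON response. A rough standalone sketch of that request follows; it assumes a LocalAI server reachable via the LOCALAI_SERVER_URL environment variable and a local audio.mp3 sample, both of which are illustrative placeholders rather than part of this change.

import os

import requests

# Illustrative placeholders: point these at your own LocalAI deployment and audio file.
server_url = os.environ.get('LOCALAI_SERVER_URL', 'http://192.168.1.100:8080')

with open('audio.mp3', 'rb') as audio_file:
    # Same endpoint and multipart layout that LocalAISpeech2text._invoke prepares.
    response = requests.post(
        f"{server_url.rstrip('/')}/v1/audio/transcriptions",
        data={"model": "whisper-1"},   # model name sent as a form field
        files={"file": audio_file},    # audio uploaded as multipart/form-data
    )

payload = response.json()
if 'error' in payload:
    # The adapter maps this case to InvokeServerUnavailableError.
    raise RuntimeError(payload['error'])
print(payload["text"])                 # transcribed text returned by LocalAI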