Add support for local ai speech to text (#3921)

Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
Tomy 2024-05-07 17:14:24 +08:00 committed by GitHub
parent d51f52a649
commit bb7c62777d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 159 additions and 0 deletions

View File

@ -15,6 +15,7 @@ help:
supported_model_types:
- llm
- text-embedding
- speech2text
configurate_methods:
- customizable-model
model_credential_schema:
@ -57,6 +58,9 @@ model_credential_schema:
zh_Hans: 在此输入LocalAI的服务器地址如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
- variable: context_size
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: 上下文大小
en_US: Context size

View File

@ -0,0 +1,101 @@
from typing import IO, Optional
from requests import Request, Session
from yarl import URL
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
class LocalAISpeech2text(Speech2TextModel):
    """
    Model class for the LocalAI speech-to-text model.

    Talks to a LocalAI server's OpenAI-compatible transcription endpoint
    (``/v1/audio/transcriptions``).
    """

    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke the speech-to-text model.

        :param model: model name
        :param credentials: model credentials; must contain 'server_url'
        :param file: audio file opened in binary mode
        :param user: unique user id (unused by the endpoint, kept for interface parity)
        :return: transcribed text for the given audio file
        :raises InvokeServerUnavailableError: if the server returns an error or a non-JSON body
        """
        # LocalAI mirrors the OpenAI transcription API shape.
        url = str(URL(credentials['server_url']) / "v1/audio/transcriptions")
        data = {"model": model}
        files = {"file": file}

        session = Session()
        request = Request("POST", url, data=data, files=files)
        prepared_request = session.prepare_request(request)
        response = session.send(prepared_request)

        # Parse the body once; a non-JSON body means the server is not a
        # working LocalAI transcription endpoint.
        try:
            payload = response.json()
        except ValueError as ex:
            raise InvokeServerUnavailableError(
                f"Non-JSON response from server: {response.text[:200]}"
            ) from ex

        if 'error' in payload:
            # Surface the server-reported error instead of a generic message.
            raise InvokeServerUnavailableError(f"Server returned error: {payload['error']}")

        return payload["text"]

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by transcribing the bundled demo audio file.

        :param model: model name
        :param credentials: model credentials; must contain 'server_url'
        :raises CredentialsValidateFailedError: if the invocation fails for any reason
        """
        try:
            audio_file_path = self._get_demo_file_path()
            with open(audio_file_path, 'rb') as audio_file:
                self._invoke(model, credentials, audio_file)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map unit invoke errors to the runtime's error types.

        Each error is already raised as the runtime's own type, so every
        category maps to itself.
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError
            ],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Build the schema for a customizable speech-to-text model.

        :param model: model name (also used as the display label)
        :param credentials: model credentials (unused; kept for interface parity)
        :return: a minimal AIModelEntity with no extra properties or parameter rules
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.SPEECH2TEXT,
            model_properties={},
            parameter_rules=[]
        )

        return entity

View File

@ -0,0 +1,54 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text
def test_validate_credentials():
    """Credential validation must reject a malformed URL and accept a live server."""
    speech2text_model = LocalAISpeech2text()

    # A nonsense server_url must raise the validation error.
    bad_credentials = {'server_url': 'invalid_url'}
    with pytest.raises(CredentialsValidateFailedError):
        speech2text_model.validate_credentials(model='whisper-1', credentials=bad_credentials)

    # Credentials pointing at the real LocalAI server (from the env) must pass.
    good_credentials = {'server_url': os.environ.get('LOCALAI_SERVER_URL')}
    speech2text_model.validate_credentials(model='whisper-1', credentials=good_credentials)
def test_invoke_model():
    """End-to-end transcription of the bundled sample audio file."""
    speech2text_model = LocalAISpeech2text()

    # Resolve <this dir>/../assets/audio.mp3 relative to this test file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    audio_file_path = os.path.join(os.path.dirname(current_dir), 'assets', 'audio.mp3')

    with open(audio_file_path, 'rb') as audio_file:
        result = speech2text_model.invoke(
            model='whisper-1',
            credentials={'server_url': os.environ.get('LOCALAI_SERVER_URL')},
            file=audio_file,
            user="abc-123"
        )

    assert isinstance(result, str)
    assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'