import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Optional

from flask import Response
from pydub import AudioSegment
from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel


class XinferenceText2SpeechModel(TTSModel):
    """Text-to-speech model backed by a Xinference server."""

    def __init__(self):
        # Preset voices per model family, keyed by language tag ('all' means
        # language-independent). These are hard-coded presets; custom voices
        # are not supported yet.
        self.model_voices = {
            'chattts': {
                'all': [
                    {'name': 'Alloy', 'value': 'alloy'},
                    {'name': 'Echo', 'value': 'echo'},
                    {'name': 'Fable', 'value': 'fable'},
                    {'name': 'Onyx', 'value': 'onyx'},
                    {'name': 'Nova', 'value': 'nova'},
                    {'name': 'Shimmer', 'value': 'shimmer'},
                ]
            },
            'cosyvoice': {
                'zh-Hans': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'zh-Hant': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'en-US': [
                    {'name': '英文男', 'value': '英文男'},
                    {'name': '英文女', 'value': '英文女'},
                ],
                'ja-JP': [
                    {'name': '日语男', 'value': '日语男'},
                ],
                'ko-KR': [
                    {'name': '韩语女', 'value': '韩语女'},
                ]
            }
        }

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by performing a short test synthesis.

        :param model: model name
        :param credentials: model credentials; must contain `server_url` and `model_uid`
        :raises CredentialsValidateFailedError: if the credentials are malformed
            or the remote model cannot be invoked as an audio model
        """
        try:
            # model_uid ends up in the request URL path, so characters with
            # URL semantics would corrupt the request.
            if any(ch in credentials['model_uid'] for ch in ('/', '?', '#')):
                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

            if credentials['server_url'].endswith('/'):
                credentials['server_url'] = credentials['server_url'][:-1]

            client = Client(base_url=credentials['server_url'])
            xinference_client = client.get_model(model_uid=credentials['model_uid'])

            if not isinstance(xinference_client, RESTfulAudioModelHandle):
                raise InvokeBadRequestError(
                    'please check model type, the model you want to invoke is not a audio model')

            # Round-trip a tiny synthesis request to prove the credentials work
            # end to end.
            self._tts_invoke(
                model=model,
                credentials=credentials,
                content_text='Hello Dify!',
                voice=self._get_model_default_voice(model, credentials),
            )
        except CredentialsValidateFailedError:
            # Already the right error type/message — don't re-wrap it.
            raise
        except Exception as ex:
            # Preserve the original cause for debugging instead of discarding it.
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
                user: Optional[str] = None):
        """
        Invoke the text-to-speech model.

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :param user: unique user id
        :return: text translated to audio file
        """
        return self._tts_invoke(model, credentials, content_text, voice)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Build the customizable-model schema shown for this TTS model.

        :param model: model name
        :param credentials: model credentials (unused; kept for interface parity)
        :return: a minimal TTS :class:`AIModelEntity`
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={},
            parameter_rules=[]
        )

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Return the preset voices for *model*, preferring the requested language
        and falling back to the language-independent 'all' group.

        :param model: model name (matched case-insensitively against known families)
        :param credentials: model credentials (unused; kept for interface parity)
        :param language: optional language tag, e.g. 'zh-Hans'
        :return: list of ``{'name': ..., 'value': ...}`` voice dicts, or [] if unknown
        """
        model_lower = model.lower()
        for family, voices in self.model_voices.items():
            if family in model_lower:
                if language in voices:
                    return voices[language]
                if 'all' in voices:
                    return voices['all']
        return []

    def _get_model_default_voice(self, model: str, credentials: dict) -> str:
        # NOTE(review): empty string defers voice selection to the server's
        # default — confirm the Xinference backend accepts an empty voice.
        return ""

    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
        # Maximum characters per synthesis request before the text is split.
        return 3500

    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
        return "mp3"

    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
        return 5

    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Optional[Response]:
        """
        Synthesize *content_text* into a single audio response.

        The text is split into sentences (bounded by the model word limit),
        each sentence is synthesized concurrently, and the clips are
        concatenated into one audio stream.

        :param model: model name
        :param credentials: model credentials (`server_url`, `model_uid`)
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: Flask ``Response`` with the combined audio, or ``None`` when the
            backend produced no audio
        :raises InvokeBadRequestError: on any synthesis or decoding failure
        """
        if credentials['server_url'].endswith('/'):
            credentials['server_url'] = credentials['server_url'][:-1]

        word_limit = self._get_model_word_limit(model, credentials)
        audio_type = self._get_model_audio_type(model, credentials)
        handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})

        try:
            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))

            # Cap fan-out at 3 concurrent requests; never ask the executor for
            # zero workers (ThreadPoolExecutor raises ValueError on 0).
            max_workers = max(1, min(3, len(sentences)))
            audio_bytes_list = []
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(
                        handle.speech, input=sentence, voice=voice,
                        response_format="mp3", speed=1.0, stream=False)
                    for sentence in sentences
                ]
                for future in futures:
                    try:
                        # Fetch once: result() both waits and re-raises worker errors.
                        audio_bytes = future.result()
                        if audio_bytes:
                            audio_bytes_list.append(audio_bytes)
                    except Exception as ex:
                        raise InvokeBadRequestError(str(ex)) from ex

            if audio_bytes_list:
                audio_segments = [
                    AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
                    for audio_bytes in audio_bytes_list if audio_bytes
                ]
                combined_segment = reduce(lambda acc, seg: acc + seg, audio_segments)
                buffer: BytesIO = BytesIO()
                combined_segment.export(buffer, format=audio_type)
                buffer.seek(0)
                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
        except Exception as ex:
            raise InvokeBadRequestError(str(ex)) from ex

    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
        """
        Streaming synthesis — not implemented.

        Attention: stream api may return error [Parallel generation is not supported by ggml]

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
        pass
+ - tts configurate_methods: - customizable-model model_credential_schema: diff --git a/api/poetry.lock b/api/poetry.lock index abde108a7a..3356d11f78 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -9098,13 +9098,13 @@ h11 = ">=0.9.0,<1" [[package]] name = "xinference-client" -version = "0.9.4" +version = "0.13.3" description = "Client for Xinference" optional = false python-versions = "*" files = [ - {file = "xinference-client-0.9.4.tar.gz", hash = "sha256:21934bc9f3142ade66aaed33c2b6cf244c274d5b4b3163f9981bebdddacf205f"}, - {file = "xinference_client-0.9.4-py3-none-any.whl", hash = "sha256:6d3f1df3537a011f0afee5f9c9ca4f3ff564ca32cc999cf7038b324c0b907d0c"}, + {file = "xinference-client-0.13.3.tar.gz", hash = "sha256:822b722100affdff049c27760be7d62ac92de58c87a40d3361066df446ba648f"}, + {file = "xinference_client-0.13.3-py3-none-any.whl", hash = "sha256:f0eff3858b1ebcef2129726f82b09259c177e11db466a7ca23def3d4849c419f"}, ] [package.dependencies] @@ -9502,4 +9502,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "a8b61d74d9322302b7447b6f8728ad606abc160202a8a122a05a8ef3cec7055b" +content-hash = "ca55e4a4bb354fe969cc73c823557525c7598b0375e8791fcd77febc59e03b96" diff --git a/api/pyproject.toml b/api/pyproject.toml index 25778f323d..112ea22da8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -173,7 +173,7 @@ transformers = "~4.35.0" unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } websocket-client = "~1.7.0" werkzeug = "~3.0.1" -xinference-client = "0.9.4" +xinference-client = "0.13.3" yarl = "~1.9.4" zhipuai = "1.0.7" rank-bm25 = "~0.2.2"