From ea30174057e384c9ddbba5d851c8056cb88c9504 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 5 Aug 2024 18:23:23 +0800 Subject: [PATCH] chore: optimize streaming tts of xinference (#6966) --- .../model_providers/xinference/tts/tts.py | 122 +++++++++--------- .../xinference/xinference_helper.py | 20 ++- 2 files changed, 78 insertions(+), 64 deletions(-) diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index c106e38781..a564a021b1 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -1,11 +1,7 @@ import concurrent.futures -from functools import reduce -from io import BytesIO from typing import Optional -from flask import Response -from pydub import AudioSegment -from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -19,6 +15,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper class XinferenceText2SpeechModel(TTSModel): @@ -26,7 +23,12 @@ class XinferenceText2SpeechModel(TTSModel): def __init__(self): # preset voices, need support custom voice self.model_voices = { - 'chattts': { + '__default': { + 'all': [ + {'name': 'Default', 'value': 'default'}, + ] + }, + 'ChatTTS': { 'all': [ {'name': 'Alloy', 'value': 'alloy'}, {'name': 'Echo', 'value': 'echo'}, @@ -36,7 +38,7 @@ class XinferenceText2SpeechModel(TTSModel): {'name': 'Shimmer', 'value': 'shimmer'}, ] }, - 'cosyvoice': { + 'CosyVoice': { 'zh-Hans': [ {'name': '中文男', 'value': '中文男'}, {'name': '中文女', 'value': '中文女'}, @@ -77,18 +79,21 @@ class XinferenceText2SpeechModel(TTSModel): if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - # initialize client - client = Client( - base_url=credentials['server_url'] + extra_param = XinferenceHelper.get_xinference_extra_parameter( + server_url=credentials['server_url'], + model_uid=credentials['model_uid'] ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) - - if not isinstance(xinference_client, RESTfulAudioModelHandle): + if 'text-to-audio' not in extra_param.model_ability: raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a audio model') + 'please check model type, the model you want to invoke is not a text-to-audio model') - self._tts_invoke( + if extra_param.model_family and extra_param.model_family in self.model_voices: + credentials['audio_model_name'] = extra_param.model_family + else: + credentials['audio_model_name'] = '__default' + + self._tts_invoke_streaming( model=model, credentials=credentials, content_text='Hello Dify!', @@ -110,7 +115,7 @@ class XinferenceText2SpeechModel(TTSModel): :param user: unique user id :return: text translated to audio file """ - return self._tts_invoke(model, credentials, content_text, voice) + return self._tts_invoke_streaming(model, credentials, content_text, voice) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ @@ -161,13 +166,15 @@ class XinferenceText2SpeechModel(TTSModel): } def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + audio_model_name = credentials.get('audio_model_name', '__default') for key, voices in self.model_voices.items(): - if key in model.lower(): - if language in voices: + if key in audio_model_name: + if language and language in voices: return voices[language] elif 'all' in voices: return voices['all'] - return [] + + return self.model_voices['__default']['all'] def _get_model_default_voice(self, model: str, credentials: dict) -> any: return "" @@ -181,60 +188,55 @@ class XinferenceText2SpeechModel(TTSModel): def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 - def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, + voice: str) -> any: """ - _tts_invoke text2speech model + _tts_invoke_streaming text2speech model :param model: model name :param credentials: model credentials - :param voice: model timbre :param content_text: text content to be translated + :param voice: model timbre :return: text translated to audio file """ if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - word_limit = self._get_model_word_limit(model, credentials) - audio_type = self._get_model_audio_type(model, credentials) - handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) - try: - sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit)) - audio_bytes_list = [] + handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) - with concurrent.futures.ThreadPoolExecutor(max_workers=min((3, len(sentences)))) as executor: + model_support_voice = [x.get("value") for x in + self.get_tts_model_voices(model=model, credentials=credentials)] + if not voice or voice not in model_support_voice: + voice = self._get_model_default_voice(model, credentials) + word_limit = self._get_model_word_limit(model, credentials) + if len(content_text) > word_limit: + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) futures = [executor.submit( - handle.speech, input=sentence, voice=voice, response_format="mp3", speed=1.0, stream=False) - for sentence in sentences] - for future in futures: - try: - if future.result(): - audio_bytes_list.append(future.result()) - except Exception as ex: - raise InvokeBadRequestError(str(ex)) + handle.speech, + input=sentences[i], + voice=voice, + response_format="mp3", + speed=1.0, + stream=False + ) + for i in range(len(sentences))] - if len(audio_bytes_list) > 0: - audio_segments = [AudioSegment.from_file( - BytesIO(audio_bytes), format=audio_type) for audio_bytes in - audio_bytes_list if audio_bytes] - combined_segment = reduce(lambda x, y: x + y, audio_segments) - buffer: BytesIO = BytesIO() - combined_segment.export(buffer, format=audio_type) - buffer.seek(0) - return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + for index, future in enumerate(futures): + response = future.result() + for i in range(0, len(response), 1024): + yield response[i:i + 1024] + else: + response = handle.speech( + input=content_text.strip(), + voice=voice, + response_format="mp3", + speed=1.0, + stream=False + ) + + for i in range(0, len(response), 1024): + yield response[i:i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) - - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: - """ - _tts_invoke_streaming text2speech model - - Attention: stream api may return error [Parallel generation is not supported by ggml] - - :param model: model name - :param credentials: model credentials - :param voice: model timbre - :param content_text: text content to be translated - :return: text translated to audio file - """ - pass diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 9a3fc9b193..7db483a485 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,5 +1,6 @@ from threading import Lock from time import time +from typing import Optional from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout @@ -15,9 +16,11 @@ class XinferenceModelExtraParameter: context_length: int = 2048 support_function_call: bool = False support_vision: bool = False + model_family: Optional[str] def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None: + support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int, + model_family: Optional[str]) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -25,6 +28,7 @@ class XinferenceModelExtraParameter: self.support_vision = support_vision self.max_tokens = max_tokens self.context_length = context_length + self.model_family = model_family cache = {} cache_lock = Lock() @@ -78,9 +82,16 @@ class XinferenceHelper: model_format = response_json.get('model_format', 'ggmlv3') model_ability = response_json.get('model_ability', []) + model_family = response_json.get('model_family', None) if response_json.get('model_type') == 'embedding': model_handle_type = 'embedding' + elif response_json.get('model_type') == 'audio': + model_handle_type = 'audio' + if model_family and model_family in ['ChatTTS', 'CosyVoice']: + model_ability.append('text-to-audio') + else: + model_ability.append('audio-to-text') elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: model_handle_type = 'chatglm' elif 'generate' in model_ability: @@ -88,7 +99,7 @@ class XinferenceHelper: elif 'chat' in model_ability: model_handle_type = 'chat' else: - raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported') + raise NotImplementedError('xinference model handle type is not supported') support_function_call = 'tools' in model_ability support_vision = 'vision' in model_ability @@ -103,5 +114,6 @@ class XinferenceHelper: support_function_call=support_function_call, support_vision=support_vision, max_tokens=max_tokens, - context_length=context_length - ) \ No newline at end of file + context_length=context_length, + model_family=model_family + )