chore: optimize streaming tts of xinference (#6966)

takatost 2024-08-05 18:23:23 +08:00 committed by GitHub
parent dd676866aa
commit ea30174057
2 changed files with 78 additions and 64 deletions

View File: api/core/model_runtime/model_providers/xinference/tts/tts.py

@@ -1,11 +1,7 @@
 import concurrent.futures
-from functools import reduce
-from io import BytesIO
 from typing import Optional
 
-from flask import Response
-from pydub import AudioSegment
-from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
+from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -19,6 +15,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
 
 
 class XinferenceText2SpeechModel(TTSModel):
@@ -26,7 +23,12 @@ class XinferenceText2SpeechModel(TTSModel):
     def __init__(self):
         # preset voices, need support custom voice
         self.model_voices = {
-            'chattts': {
+            '__default': {
+                'all': [
+                    {'name': 'Default', 'value': 'default'},
+                ]
+            },
+            'ChatTTS': {
                 'all': [
                     {'name': 'Alloy', 'value': 'alloy'},
                     {'name': 'Echo', 'value': 'echo'},
@@ -36,7 +38,7 @@ class XinferenceText2SpeechModel(TTSModel):
                     {'name': 'Shimmer', 'value': 'shimmer'},
                 ]
             },
-            'cosyvoice': {
+            'CosyVoice': {
                 'zh-Hans': [
                     {'name': '中文男', 'value': '中文男'},
                     {'name': '中文女', 'value': '中文女'},
@@ -77,18 +79,21 @@
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
 
-        # initialize client
-        client = Client(
-            base_url=credentials['server_url']
+        extra_param = XinferenceHelper.get_xinference_extra_parameter(
+            server_url=credentials['server_url'],
+            model_uid=credentials['model_uid']
         )
 
-        xinference_client = client.get_model(model_uid=credentials['model_uid'])
-
-        if not isinstance(xinference_client, RESTfulAudioModelHandle):
+        if 'text-to-audio' not in extra_param.model_ability:
             raise InvokeBadRequestError(
-                'please check model type, the model you want to invoke is not a audio model')
+                'please check model type, the model you want to invoke is not a text-to-audio model')
 
-        self._tts_invoke(
+        if extra_param.model_family and extra_param.model_family in self.model_voices:
+            credentials['audio_model_name'] = extra_param.model_family
+        else:
+            credentials['audio_model_name'] = '__default'
+
+        self._tts_invoke_streaming(
             model=model,
             credentials=credentials,
             content_text='Hello Dify!',
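Credential validation no longer instantiates a RESTful client just to probe the model type: it asks XinferenceHelper for the model's metadata, checks the ability list, and caches the resolved model family in the credentials for later voice lookups. A minimal standalone sketch mirroring that check (the server URL and model UID below are hypothetical placeholders, not values from this commit):

# Hypothetical deployment values; substitute a real server_url and model_uid.
extra_param = XinferenceHelper.get_xinference_extra_parameter(
    server_url='http://127.0.0.1:9997',
    model_uid='my-chattts-uid'
)
print(extra_param.model_family)                        # e.g. 'ChatTTS'
print('text-to-audio' in extra_param.model_ability)    # must be True for TTS use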
@@ -110,7 +115,7 @@ class XinferenceText2SpeechModel(TTSModel):
         :param user: unique user id
         :return: text translated to audio file
         """
-        return self._tts_invoke(model, credentials, content_text, voice)
+        return self._tts_invoke_streaming(model, credentials, content_text, voice)
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
@@ -161,13 +166,15 @@
         }
 
     def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
+        audio_model_name = credentials.get('audio_model_name', '__default')
         for key, voices in self.model_voices.items():
-            if key in model.lower():
-                if language in voices:
+            if key in audio_model_name:
+                if language and language in voices:
                     return voices[language]
                 elif 'all' in voices:
                     return voices['all']
-        return []
+
+        return self.model_voices['__default']['all']
 
     def _get_model_default_voice(self, model: str, credentials: dict) -> any:
         return ""
@@ -181,60 +188,55 @@
     def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
         return 5
 
-    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
+                              voice: str) -> any:
         """
-        _tts_invoke text2speech model
+        _tts_invoke_streaming text2speech model
 
         :param model: model name
         :param credentials: model credentials
-        :param voice: model timbre
         :param content_text: text content to be translated
+        :param voice: model timbre
         :return: text translated to audio file
         """
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
 
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
         try:
-            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
-            audio_bytes_list = []
-
-            with concurrent.futures.ThreadPoolExecutor(max_workers=min((3, len(sentences)))) as executor:
-                futures = [executor.submit(
-                    handle.speech, input=sentence, voice=voice, response_format="mp3", speed=1.0, stream=False)
-                    for sentence in sentences]
-                for future in futures:
-                    try:
-                        if future.result():
-                            audio_bytes_list.append(future.result())
-                    except Exception as ex:
-                        raise InvokeBadRequestError(str(ex))
-
-            if len(audio_bytes_list) > 0:
-                audio_segments = [AudioSegment.from_file(
-                    BytesIO(audio_bytes), format=audio_type) for audio_bytes in
-                    audio_bytes_list if audio_bytes]
-                combined_segment = reduce(lambda x, y: x + y, audio_segments)
-                buffer: BytesIO = BytesIO()
-                combined_segment.export(buffer, format=audio_type)
-                buffer.seek(0)
-                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
+            handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
+
+            model_support_voice = [x.get("value") for x in
+                                   self.get_tts_model_voices(model=model, credentials=credentials)]
+            if not voice or voice not in model_support_voice:
+                voice = self._get_model_default_voice(model, credentials)
+            word_limit = self._get_model_word_limit(model, credentials)
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
+                futures = [executor.submit(
+                    handle.speech,
+                    input=sentences[i],
+                    voice=voice,
+                    response_format="mp3",
+                    speed=1.0,
+                    stream=False
+                )
+                    for i in range(len(sentences))]
+                for index, future in enumerate(futures):
+                    response = future.result()
+                    for i in range(0, len(response), 1024):
+                        yield response[i:i + 1024]
+            else:
+                response = handle.speech(
+                    input=content_text.strip(),
+                    voice=voice,
+                    response_format="mp3",
+                    speed=1.0,
+                    stream=False
+                )
+
+                for i in range(0, len(response), 1024):
+                    yield response[i:i + 1024]
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
-
-    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
-        """
-        _tts_invoke_streaming text2speech model
-
-        Attention: stream api may return error [Parallel generation is not supported by ggml]
-
-        :param model: model name
-        :param credentials: model credentials
-        :param voice: model timbre
-        :param content_text: text content to be translated
-        :return: text translated to audio file
-        """
-        pass
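The merged method is now a generator: instead of concatenating pydub segments into a single Flask Response, it yields raw mp3 bytes in 1 KiB slices as each future resolves, which is what lets the API start streaming audio before the whole text is synthesised. The slicing itself is plain range arithmetic; a tiny illustration with a fake payload:

response = b'\x00' * 2500  # stand-in for one mp3 result from handle.speech
chunks = [response[i:i + 1024] for i in range(0, len(response), 1024)]
assert [len(c) for c in chunks] == [1024, 1024, 452]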

View File: api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -1,5 +1,6 @@
 from threading import Lock
 from time import time
+from typing import Optional
 
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
@@ -15,9 +16,11 @@ class XinferenceModelExtraParameter:
     context_length: int = 2048
     support_function_call: bool = False
     support_vision: bool = False
+    model_family: Optional[str]
 
     def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
-                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
+                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int,
+                 model_family: Optional[str]) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
@@ -25,6 +28,7 @@ class XinferenceModelExtraParameter:
         self.support_vision = support_vision
         self.max_tokens = max_tokens
         self.context_length = context_length
+        self.model_family = model_family
 
 
 cache = {}
 cache_lock = Lock()
@@ -78,9 +82,16 @@ class XinferenceHelper:
         model_format = response_json.get('model_format', 'ggmlv3')
         model_ability = response_json.get('model_ability', [])
+        model_family = response_json.get('model_family', None)
 
         if response_json.get('model_type') == 'embedding':
             model_handle_type = 'embedding'
+        elif response_json.get('model_type') == 'audio':
+            model_handle_type = 'audio'
+            if model_family and model_family in ['ChatTTS', 'CosyVoice']:
+                model_ability.append('text-to-audio')
+            else:
+                model_ability.append('audio-to-text')
         elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
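The helper now recognises audio models and derives an ability flag the TTS model can check: the ChatTTS and CosyVoice families are tagged 'text-to-audio', while any other audio family is assumed to be 'audio-to-text'. A condensed, runnable sketch of just this branch, fed a hand-written response payload:

response_json = {'model_type': 'audio', 'model_family': 'ChatTTS', 'model_ability': []}

model_ability = response_json.get('model_ability', [])
model_family = response_json.get('model_family', None)
if response_json.get('model_type') == 'audio':
    # ChatTTS / CosyVoice synthesise speech; other audio families transcribe it
    if model_family and model_family in ['ChatTTS', 'CosyVoice']:
        model_ability.append('text-to-audio')
    else:
        model_ability.append('audio-to-text')
assert model_ability == ['text-to-audio']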
@@ -88,7 +99,7 @@ class XinferenceHelper:
         elif 'chat' in model_ability:
             model_handle_type = 'chat'
         else:
-            raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
+            raise NotImplementedError('xinference model handle type is not supported')
 
         support_function_call = 'tools' in model_ability
         support_vision = 'vision' in model_ability
@@ -103,5 +114,6 @@ class XinferenceHelper:
             support_function_call=support_function_call,
             support_vision=support_vision,
             max_tokens=max_tokens,
-            context_length=context_length
+            context_length=context_length,
+            model_family=model_family
         )
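With model_family threaded through the constructor, callers receive it alongside the other extra parameters. A minimal construction, using made-up values only to show the new keyword:

param = XinferenceModelExtraParameter(
    model_format='pytorch',
    model_handle_type='audio',
    model_ability=['text-to-audio'],
    support_function_call=False,
    support_vision=False,
    max_tokens=512,
    context_length=2048,
    model_family='ChatTTS'   # new field added by this commit
)
assert param.model_family == 'ChatTTS'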