mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 12:56:01 +08:00
chore: optimize streaming tts of xinference (#6966)
This commit is contained in:
parent
dd676866aa
commit
ea30174057
@ -1,11 +1,7 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from functools import reduce
|
|
||||||
from io import BytesIO
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask import Response
|
from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle
|
||||||
from pydub import AudioSegment
|
|
||||||
from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
|
|
||||||
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||||
@ -19,6 +15,7 @@ from core.model_runtime.errors.invoke import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
|
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
|
||||||
|
|
||||||
|
|
||||||
class XinferenceText2SpeechModel(TTSModel):
|
class XinferenceText2SpeechModel(TTSModel):
|
||||||
@ -26,7 +23,12 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
# preset voices, need support custom voice
|
# preset voices, need support custom voice
|
||||||
self.model_voices = {
|
self.model_voices = {
|
||||||
'chattts': {
|
'__default': {
|
||||||
|
'all': [
|
||||||
|
{'name': 'Default', 'value': 'default'},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
'ChatTTS': {
|
||||||
'all': [
|
'all': [
|
||||||
{'name': 'Alloy', 'value': 'alloy'},
|
{'name': 'Alloy', 'value': 'alloy'},
|
||||||
{'name': 'Echo', 'value': 'echo'},
|
{'name': 'Echo', 'value': 'echo'},
|
||||||
@ -36,7 +38,7 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
{'name': 'Shimmer', 'value': 'shimmer'},
|
{'name': 'Shimmer', 'value': 'shimmer'},
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
'cosyvoice': {
|
'CosyVoice': {
|
||||||
'zh-Hans': [
|
'zh-Hans': [
|
||||||
{'name': '中文男', 'value': '中文男'},
|
{'name': '中文男', 'value': '中文男'},
|
||||||
{'name': '中文女', 'value': '中文女'},
|
{'name': '中文女', 'value': '中文女'},
|
||||||
@ -77,18 +79,21 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
if credentials['server_url'].endswith('/'):
|
if credentials['server_url'].endswith('/'):
|
||||||
credentials['server_url'] = credentials['server_url'][:-1]
|
credentials['server_url'] = credentials['server_url'][:-1]
|
||||||
|
|
||||||
# initialize client
|
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||||
client = Client(
|
server_url=credentials['server_url'],
|
||||||
base_url=credentials['server_url']
|
model_uid=credentials['model_uid']
|
||||||
)
|
)
|
||||||
|
|
||||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
if 'text-to-audio' not in extra_param.model_ability:
|
||||||
|
|
||||||
if not isinstance(xinference_client, RESTfulAudioModelHandle):
|
|
||||||
raise InvokeBadRequestError(
|
raise InvokeBadRequestError(
|
||||||
'please check model type, the model you want to invoke is not a audio model')
|
'please check model type, the model you want to invoke is not a text-to-audio model')
|
||||||
|
|
||||||
self._tts_invoke(
|
if extra_param.model_family and extra_param.model_family in self.model_voices:
|
||||||
|
credentials['audio_model_name'] = extra_param.model_family
|
||||||
|
else:
|
||||||
|
credentials['audio_model_name'] = '__default'
|
||||||
|
|
||||||
|
self._tts_invoke_streaming(
|
||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
content_text='Hello Dify!',
|
content_text='Hello Dify!',
|
||||||
@ -110,7 +115,7 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:return: text translated to audio file
|
:return: text translated to audio file
|
||||||
"""
|
"""
|
||||||
return self._tts_invoke(model, credentials, content_text, voice)
|
return self._tts_invoke_streaming(model, credentials, content_text, voice)
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
"""
|
"""
|
||||||
@ -161,13 +166,15 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||||
|
audio_model_name = credentials.get('audio_model_name', '__default')
|
||||||
for key, voices in self.model_voices.items():
|
for key, voices in self.model_voices.items():
|
||||||
if key in model.lower():
|
if key in audio_model_name:
|
||||||
if language in voices:
|
if language and language in voices:
|
||||||
return voices[language]
|
return voices[language]
|
||||||
elif 'all' in voices:
|
elif 'all' in voices:
|
||||||
return voices['all']
|
return voices['all']
|
||||||
return []
|
|
||||||
|
return self.model_voices['__default']['all']
|
||||||
|
|
||||||
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
|
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
|
||||||
return ""
|
return ""
|
||||||
@ -181,60 +188,55 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||||
return 5
|
return 5
|
||||||
|
|
||||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
|
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
|
||||||
|
voice: str) -> any:
|
||||||
"""
|
"""
|
||||||
_tts_invoke text2speech model
|
_tts_invoke_streaming text2speech model
|
||||||
|
|
||||||
:param model: model name
|
:param model: model name
|
||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
:param voice: model timbre
|
|
||||||
:param content_text: text content to be translated
|
:param content_text: text content to be translated
|
||||||
|
:param voice: model timbre
|
||||||
:return: text translated to audio file
|
:return: text translated to audio file
|
||||||
"""
|
"""
|
||||||
if credentials['server_url'].endswith('/'):
|
if credentials['server_url'].endswith('/'):
|
||||||
credentials['server_url'] = credentials['server_url'][:-1]
|
credentials['server_url'] = credentials['server_url'][:-1]
|
||||||
|
|
||||||
word_limit = self._get_model_word_limit(model, credentials)
|
try:
|
||||||
audio_type = self._get_model_audio_type(model, credentials)
|
|
||||||
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
|
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
|
||||||
|
|
||||||
try:
|
model_support_voice = [x.get("value") for x in
|
||||||
sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
|
self.get_tts_model_voices(model=model, credentials=credentials)]
|
||||||
audio_bytes_list = []
|
if not voice or voice not in model_support_voice:
|
||||||
|
voice = self._get_model_default_voice(model, credentials)
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=min((3, len(sentences)))) as executor:
|
word_limit = self._get_model_word_limit(model, credentials)
|
||||||
|
if len(content_text) > word_limit:
|
||||||
|
sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
|
||||||
|
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
|
||||||
futures = [executor.submit(
|
futures = [executor.submit(
|
||||||
handle.speech, input=sentence, voice=voice, response_format="mp3", speed=1.0, stream=False)
|
handle.speech,
|
||||||
for sentence in sentences]
|
input=sentences[i],
|
||||||
for future in futures:
|
voice=voice,
|
||||||
try:
|
response_format="mp3",
|
||||||
if future.result():
|
speed=1.0,
|
||||||
audio_bytes_list.append(future.result())
|
stream=False
|
||||||
|
)
|
||||||
|
for i in range(len(sentences))]
|
||||||
|
|
||||||
|
for index, future in enumerate(futures):
|
||||||
|
response = future.result()
|
||||||
|
for i in range(0, len(response), 1024):
|
||||||
|
yield response[i:i + 1024]
|
||||||
|
else:
|
||||||
|
response = handle.speech(
|
||||||
|
input=content_text.strip(),
|
||||||
|
voice=voice,
|
||||||
|
response_format="mp3",
|
||||||
|
speed=1.0,
|
||||||
|
stream=False
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, len(response), 1024):
|
||||||
|
yield response[i:i + 1024]
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise InvokeBadRequestError(str(ex))
|
raise InvokeBadRequestError(str(ex))
|
||||||
|
|
||||||
if len(audio_bytes_list) > 0:
|
|
||||||
audio_segments = [AudioSegment.from_file(
|
|
||||||
BytesIO(audio_bytes), format=audio_type) for audio_bytes in
|
|
||||||
audio_bytes_list if audio_bytes]
|
|
||||||
combined_segment = reduce(lambda x, y: x + y, audio_segments)
|
|
||||||
buffer: BytesIO = BytesIO()
|
|
||||||
combined_segment.export(buffer, format=audio_type)
|
|
||||||
buffer.seek(0)
|
|
||||||
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
|
|
||||||
except Exception as ex:
|
|
||||||
raise InvokeBadRequestError(str(ex))
|
|
||||||
|
|
||||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
|
|
||||||
"""
|
|
||||||
_tts_invoke_streaming text2speech model
|
|
||||||
|
|
||||||
Attention: stream api may return error [Parallel generation is not supported by ggml]
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param voice: model timbre
|
|
||||||
:param content_text: text content to be translated
|
|
||||||
:return: text translated to audio file
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from threading import Lock
|
from threading import Lock
|
||||||
from time import time
|
from time import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
from requests.exceptions import ConnectionError, MissingSchema, Timeout
|
from requests.exceptions import ConnectionError, MissingSchema, Timeout
|
||||||
@ -15,9 +16,11 @@ class XinferenceModelExtraParameter:
|
|||||||
context_length: int = 2048
|
context_length: int = 2048
|
||||||
support_function_call: bool = False
|
support_function_call: bool = False
|
||||||
support_vision: bool = False
|
support_vision: bool = False
|
||||||
|
model_family: Optional[str]
|
||||||
|
|
||||||
def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
|
def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
|
||||||
support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
|
support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int,
|
||||||
|
model_family: Optional[str]) -> None:
|
||||||
self.model_format = model_format
|
self.model_format = model_format
|
||||||
self.model_handle_type = model_handle_type
|
self.model_handle_type = model_handle_type
|
||||||
self.model_ability = model_ability
|
self.model_ability = model_ability
|
||||||
@ -25,6 +28,7 @@ class XinferenceModelExtraParameter:
|
|||||||
self.support_vision = support_vision
|
self.support_vision = support_vision
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
|
self.model_family = model_family
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
cache_lock = Lock()
|
cache_lock = Lock()
|
||||||
@ -78,9 +82,16 @@ class XinferenceHelper:
|
|||||||
|
|
||||||
model_format = response_json.get('model_format', 'ggmlv3')
|
model_format = response_json.get('model_format', 'ggmlv3')
|
||||||
model_ability = response_json.get('model_ability', [])
|
model_ability = response_json.get('model_ability', [])
|
||||||
|
model_family = response_json.get('model_family', None)
|
||||||
|
|
||||||
if response_json.get('model_type') == 'embedding':
|
if response_json.get('model_type') == 'embedding':
|
||||||
model_handle_type = 'embedding'
|
model_handle_type = 'embedding'
|
||||||
|
elif response_json.get('model_type') == 'audio':
|
||||||
|
model_handle_type = 'audio'
|
||||||
|
if model_family and model_family in ['ChatTTS', 'CosyVoice']:
|
||||||
|
model_ability.append('text-to-audio')
|
||||||
|
else:
|
||||||
|
model_ability.append('audio-to-text')
|
||||||
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
|
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
|
||||||
model_handle_type = 'chatglm'
|
model_handle_type = 'chatglm'
|
||||||
elif 'generate' in model_ability:
|
elif 'generate' in model_ability:
|
||||||
@ -88,7 +99,7 @@ class XinferenceHelper:
|
|||||||
elif 'chat' in model_ability:
|
elif 'chat' in model_ability:
|
||||||
model_handle_type = 'chat'
|
model_handle_type = 'chat'
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
|
raise NotImplementedError('xinference model handle type is not supported')
|
||||||
|
|
||||||
support_function_call = 'tools' in model_ability
|
support_function_call = 'tools' in model_ability
|
||||||
support_vision = 'vision' in model_ability
|
support_vision = 'vision' in model_ability
|
||||||
@ -103,5 +114,6 @@ class XinferenceHelper:
|
|||||||
support_function_call=support_function_call,
|
support_function_call=support_function_call,
|
||||||
support_vision=support_vision,
|
support_vision=support_vision,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
context_length=context_length
|
context_length=context_length,
|
||||||
|
model_family=model_family
|
||||||
)
|
)
|
Loading…
x
Reference in New Issue
Block a user