From f76ac8bdee022093d847094818e05f94ec24b309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Tue, 23 Apr 2024 17:09:30 +0800 Subject: [PATCH] enhance:speedup xinference audio transcription (#3636) --- .../xinference/speech2text/speech2text.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 35269fceca..f60d8d3443 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel): if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + # initialize client + client = Client( + base_url=credentials['server_url'] + ) + + xinference_client = client.get_model(model_uid=credentials['model_uid']) + + if not isinstance(xinference_client, RESTfulAudioModelHandle): + raise InvokeBadRequestError( + 'please check model type, the model you want to invoke is not a audio model') + audio_file_path = self._get_demo_file_path() with open(audio_file_path, 'rb') as audio_file: @@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel): if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - # initialize client - client = Client( - base_url=credentials['server_url'] - ) - - xinference_client = client.get_model(model_uid=credentials['model_uid']) - - if not isinstance(xinference_client, RESTfulAudioModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model') - - response = xinference_client.transcriptions( + handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={}) + response = handle.transcriptions( audio=file, language = language, prompt = prompt,