diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 35269fceca..f60d8d3443 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel): if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + # initialize client + client = Client( + base_url=credentials['server_url'] + ) + + xinference_client = client.get_model(model_uid=credentials['model_uid']) + + if not isinstance(xinference_client, RESTfulAudioModelHandle): + raise InvokeBadRequestError( + 'please check model type, the model you want to invoke is not a audio model') + audio_file_path = self._get_demo_file_path() with open(audio_file_path, 'rb') as audio_file: @@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel): if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - # initialize client - client = Client( - base_url=credentials['server_url'] - ) - - xinference_client = client.get_model(model_uid=credentials['model_uid']) - - if not isinstance(xinference_client, RESTfulAudioModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model') - - response = xinference_client.transcriptions( + handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={}) + response = handle.transcriptions( audio=file, language = language, prompt = prompt,