enhance:speedup xinference audio transcription (#3636)

This commit is contained in:
呆萌闷油瓶 2024-04-23 17:09:30 +08:00 committed by GitHub
parent 83caffe000
commit f76ac8bdee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulAudioModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a audio model')
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulAudioModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')
response = xinference_client.transcriptions(
handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
response = handle.transcriptions(
audio=file,
language = language,
prompt = prompt,