Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-12 19:08:58 +08:00)
feat:support azure whisper model and fix:rename text-embedidng-ada-002.yaml to text-embedding-ada-002.yaml (#2732)
parent 8fe83750b7
commit 9819ad347f
@@ -526,3 +526,20 @@ EMBEDDING_BASE_MODELS = [
         )
     )
 ]
+
+SPEECH2TEXT_BASE_MODELS = [
+    AzureBaseModel(
+        base_model_name='whisper-1',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label'
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={
+                ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
+                ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
+            }
+        )
+    )
+]
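The two model_properties declared in this new azure_openai _constant.py entry are presumably what the runtime uses to validate uploaded audio. A minimal sketch, not part of the commit: check_upload is a hypothetical helper, the import path for ModelPropertyKey is assumed to be the same module AIModelEntity comes from later in this diff, and treating file_upload_limit as megabytes is an assumption.

# Hypothetical pre-check built on the two properties above.
# Assumptions: ModelPropertyKey lives in core.model_runtime.entities.model_entities,
# and FILE_UPLOAD_LIMIT is expressed in megabytes.
from core.model_runtime.entities.model_entities import ModelPropertyKey


def check_upload(filename: str, size_mb: float, model_properties: dict) -> bool:
    allowed = model_properties[ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS].split(',')
    limit_mb = model_properties[ModelPropertyKey.FILE_UPLOAD_LIMIT]
    extension = filename.rsplit('.', 1)[-1].lower()
    return extension in allowed and size_mb <= limit_mb


props = {
    ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
    ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm',
}
print(check_upload('voice.ogg', 10, props))   # True: ogg is in the list and 10 <= 25
print(check_upload('voice.aac', 10, props))   # False: aac is not a supported extension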
@@ -15,6 +15,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - speech2text
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -99,6 +100,12 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: text-embedding
+        - label:
+            en_US: whisper-1
+          value: whisper-1
+          show_on:
+            - variable: __model_type
+              value: speech2text
       placeholder:
         zh_Hans: 在此输入您的模型版本
         en_US: Enter your model version
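With the provider YAML above now listing speech2text and offering whisper-1 as a base_model_name option (shown only when __model_type is speech2text), a configured model would carry credentials along these lines. A minimal sketch: the field names are assumed to mirror the existing azure_openai credential schema, and every value is a placeholder.

# Placeholder credentials for a customizable Azure speech2text model.
# Assumption: the azure_openai schema uses openai_api_base / openai_api_key /
# openai_api_version for the endpoint, key and API version fields.
credentials = {
    'openai_api_base': 'https://<your-resource>.openai.azure.com/',
    'openai_api_key': '<azure-api-key>',
    'openai_api_version': '2024-02-01',
    'base_model_name': 'whisper-1',   # the new option added in this hunk
}

# The Azure deployment name is supplied as the model name itself,
# e.g. model = 'my-whisper-deployment'.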
@ -0,0 +1,81 @@
|
|||||||
|
from typing import IO, Optional
|
||||||
|
|
||||||
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||||
|
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||||
|
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
||||||
|
"""
|
||||||
|
Model class for OpenAI Speech to text model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
file: IO[bytes], user: Optional[str] = None) \
|
||||||
|
-> str:
|
||||||
|
"""
|
||||||
|
Invoke speech2text model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param file: audio file
|
||||||
|
:param user: unique user id
|
||||||
|
:return: text for given audio file
|
||||||
|
"""
|
||||||
|
return self._speech2text_invoke(model, credentials, file)
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
audio_file_path = self._get_demo_file_path()
|
||||||
|
|
||||||
|
with open(audio_file_path, 'rb') as audio_file:
|
||||||
|
self._speech2text_invoke(model, credentials, audio_file)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
|
||||||
|
"""
|
||||||
|
Invoke speech2text model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param file: audio file
|
||||||
|
:return: text for given audio file
|
||||||
|
"""
|
||||||
|
# transform credentials to kwargs for model instance
|
||||||
|
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||||
|
|
||||||
|
# init model client
|
||||||
|
client = AzureOpenAI(**credentials_kwargs)
|
||||||
|
|
||||||
|
response = client.audio.transcriptions.create(model=model, file=file)
|
||||||
|
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||||
|
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||||
|
return ai_model_entity.entity
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||||
|
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
|
||||||
|
if ai_model_entity.base_model_name == base_model_name:
|
||||||
|
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
||||||
|
ai_model_entity_copy.entity.model = model
|
||||||
|
ai_model_entity_copy.entity.label.en_US = model
|
||||||
|
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||||
|
return ai_model_entity_copy
|
||||||
|
|
||||||
|
return None
|
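_speech2text_invoke above wraps a single Azure OpenAI SDK call, so it can be exercised directly to see what the new class does under the hood. A minimal sketch, assuming a whisper-1 deployment named 'my-whisper-deployment' and key-based authentication; endpoint, key and API version are placeholders standing in for what _to_credential_kwargs(credentials) would produce.

from openai import AzureOpenAI

# Stand-in for the kwargs built by _to_credential_kwargs(credentials); all values are placeholders.
client = AzureOpenAI(
    azure_endpoint='https://<your-resource>.openai.azure.com/',
    api_key='<azure-api-key>',
    api_version='2024-02-01',
)

with open('sample.wav', 'rb') as audio_file:
    # 'my-whisper-deployment' is a hypothetical deployment backed by the whisper-1 base model.
    response = client.audio.transcriptions.create(model='my-whisper-deployment', file=audio_file)

print(response.text)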
@@ -2,4 +2,4 @@ model: whisper-1
 model_type: speech2text
 model_properties:
   file_upload_limit: 25
-  supported_file_extensions: mp3,mp4,mpeg,mpga,m4a,wav,webm
+  supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
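This predefined whisper-1 model YAML now carries the same extension list as the SPEECH2TEXT_BASE_MODELS constant above (flac and ogg are the additions). A minimal sketch of reading the properties back with PyYAML; the file path is an assumption about the repository layout.

import yaml

# Assumed location of the predefined whisper-1 model YAML; adjust to the actual layout.
with open('api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml') as f:
    spec = yaml.safe_load(f)

extensions = spec['model_properties']['supported_file_extensions'].split(',')
print(spec['model_properties']['file_upload_limit'], extensions)
# Expected: 25 ['flac', 'mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'ogg', 'wav', 'webm']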