mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 04:18:58 +08:00
feat: support doubao llm and embedding models (#4431)
This commit is contained in:
parent
dd94931116
commit
6e9066ebf4
@ -1,4 +1,64 @@
|
||||
ModelConfigs = {
|
||||
'Doubao-pro-4k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 4096,
|
||||
'max_new_tokens': 4096,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'Doubao-lite-4k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 4096,
|
||||
'max_new_tokens': 4096,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'Doubao-pro-32k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 32768,
|
||||
'max_new_tokens': 32768,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 32768,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'Doubao-lite-32k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 32768,
|
||||
'max_new_tokens': 32768,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 32768,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'Doubao-pro-128k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 131072,
|
||||
'max_new_tokens': 131072,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 131072,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'Doubao-lite-128k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 131072,
|
||||
'max_new_tokens': 131072,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 131072,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'Skylark2-pro-4k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 4096,
|
||||
@ -8,5 +68,5 @@ ModelConfigs = {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -0,0 +1,9 @@
|
||||
# Per-base-model defaults for text-embedding models, keyed by base model name.
# Each entry carries the request parameters and the model properties to use
# when the credentials do not override them.
ModelConfigs = {
    "Doubao-embedding": {
        "req_params": {},
        "model_properties": {"context_size": 4096, "max_chunks": 1},
    },
}
|
@ -1,7 +1,16 @@
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
@ -21,6 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
|
||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
||||
|
||||
|
||||
@ -45,7 +55,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials, tokens=resp['total_tokens'])
|
||||
model=model, credentials=credentials, tokens=resp['usage']['total_tokens'])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
@ -101,6 +111,34 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
InvokeBadRequestError: BadRequestErrors.values(),
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
    """
    Generate the custom model entity from credentials.

    Model properties start from the ``ModelConfigs`` entry selected by
    ``credentials['base_model_name']`` (empty when the name is unknown) and
    are then overridden by ``context_size`` / ``max_chunks`` when the
    credentials carry truthy values for them.

    :param model: model name; used as the entity's model id and en_US label
    :param credentials: model credentials. ``base_model_name`` is required;
        ``context_size``, ``max_chunks``, ``input_price``, ``unit`` and
        ``currency`` are optional
    :return: text-embedding model entity fetched from a customizable model
    :raises KeyError: if ``base_model_name`` is missing from credentials
    """
    # Copy so the per-call overrides below never mutate the shared defaults.
    base_config = ModelConfigs.get(credentials['base_model_name'], {})
    model_properties = base_config.get('model_properties', {}).copy()

    # Only truthy values override the base-model defaults; a missing, empty
    # or zero value leaves the default in place.
    context_size = credentials.get('context_size')
    if context_size:
        model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(context_size)

    max_chunks = credentials.get('max_chunks')
    if max_chunks:
        model_properties[ModelPropertyKey.MAX_CHUNKS] = int(max_chunks)

    # Go through str() so a float such as 0.0001 becomes Decimal('0.0001')
    # rather than its exact (noisy) binary expansion; str and int inputs
    # convert exactly as before.
    input_price = Decimal(str(credentials.get('input_price', 0)))
    unit = Decimal(str(credentials.get('unit', 0)))

    return AIModelEntity(
        model=model,
        label=I18nObject(en_US=model),
        model_type=ModelType.TEXT_EMBEDDING,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties=model_properties,
        parameter_rules=[],
        pricing=PriceConfig(
            input=input_price,
            unit=unit,
            currency=credentials.get('currency', "USD"),
        ),
    )
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
@ -76,21 +76,60 @@ model_credential_schema:
|
||||
en_US: Enter your Endpoint ID
|
||||
zh_Hans: 输入您的 Endpoint ID
|
||||
- variable: base_model_name
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Base Model
|
||||
zh_Hans: 基础模型
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: Doubao-pro-4k
|
||||
value: Doubao-pro-4k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-lite-4k
|
||||
value: Doubao-lite-4k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-pro-32k
|
||||
value: Doubao-pro-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-lite-32k
|
||||
value: Doubao-lite-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-pro-128k
|
||||
value: Doubao-pro-128k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-lite-128k
|
||||
value: Doubao-lite-128k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Skylark2-pro-4k
|
||||
value: Skylark2-pro-4k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-embedding
|
||||
value: Doubao-embedding
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: Custom
|
||||
zh_Hans: 自定义
|
||||
@ -122,8 +161,6 @@ model_credential_schema:
|
||||
- variable: context_size
|
||||
required: true
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- variable: base_model_name
|
||||
value: Custom
|
||||
label:
|
||||
|
@ -21,6 +21,7 @@ def test_validate_credentials():
|
||||
'volc_access_key_id': 'INVALID',
|
||||
'volc_secret_access_key': 'INVALID',
|
||||
'endpoint_id': 'INVALID',
|
||||
'base_model_name': 'Doubao-embedding',
|
||||
}
|
||||
)
|
||||
|
||||
@ -32,6 +33,7 @@ def test_validate_credentials():
|
||||
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
||||
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
||||
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
||||
'base_model_name': 'Doubao-embedding',
|
||||
},
|
||||
)
|
||||
|
||||
@ -47,6 +49,7 @@ def test_invoke_model():
|
||||
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
||||
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
||||
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
||||
'base_model_name': 'Doubao-embedding',
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
@ -71,6 +74,7 @@ def test_get_num_tokens():
|
||||
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
||||
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
||||
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
||||
'base_model_name': 'Doubao-embedding',
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
|
Loading…
x
Reference in New Issue
Block a user