feat: support doubao llm and embedding models (#4431)

This commit is contained in:
sino 2024-05-16 11:41:24 +08:00 committed by GitHub
parent dd94931116
commit 6e9066ebf4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 156 additions and 8 deletions

View File

@ -1,4 +1,64 @@
ModelConfigs = {
'Doubao-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
}
},
'Doubao-lite-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
}
},
'Doubao-pro-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
}
},
'Doubao-lite-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
}
},
'Doubao-pro-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
}
},
'Doubao-lite-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
}
},
'Skylark2-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
@ -8,5 +68,5 @@ ModelConfigs = {
'context_size': 4096,
'mode': 'chat',
}
}
},
}

View File

@ -0,0 +1,9 @@
# Preset configuration for text-embedding base models, keyed by model name.
# Each entry carries two sections:
#   'req_params'       - request parameters (empty for this model)
#   'model_properties' - default properties such as context size and max chunks
ModelConfigs = {
    'Doubao-embedding': {
        'req_params': {},
        'model_properties': {'context_size': 4096, 'max_chunks': 1},
    },
}

View File

@ -1,7 +1,16 @@
import time
from decimal import Decimal
from typing import Optional
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
PriceConfig,
PriceType,
)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@ -21,6 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
@ -45,7 +55,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=resp['total_tokens'])
model=model, credentials=credentials, tokens=resp['usage']['total_tokens'])
result = TextEmbeddingResult(
model=model,
@ -101,6 +111,34 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
InvokeBadRequestError: BadRequestErrors.values(),
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
    """
    Generate the custom text-embedding model entity from credentials.

    Preset properties for the selected base model are read from ``ModelConfigs``
    and then overridden by ``context_size`` / ``max_chunks`` when the credentials
    provide a truthy value for them.

    :param model: model name, used as the entity's model id and English label
    :param credentials: model credentials; the keys read here are
        ``base_model_name``, ``context_size``, ``max_chunks``, ``input_price``,
        ``unit`` and ``currency`` (all optional)
    :return: model entity of type TEXT_EMBEDDING, fetched from CUSTOMIZABLE_MODEL
    """
    # Credentials may lack 'base_model_name'; look it up with .get so a missing
    # key yields "no preset properties" instead of a KeyError.
    base_model_name = credentials.get('base_model_name')
    model_properties = ModelConfigs.get(base_model_name, {}).get('model_properties', {}).copy()

    # Explicit values in the credentials take precedence over the presets.
    context_size = credentials.get('context_size')
    if context_size:
        model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(context_size)

    max_chunks = credentials.get('max_chunks')
    if max_chunks:
        model_properties[ModelPropertyKey.MAX_CHUNKS] = int(max_chunks)

    # Go through str() so that a float price (e.g. 0.1) becomes Decimal('0.1')
    # rather than the float's exact binary expansion.
    pricing = PriceConfig(
        input=Decimal(str(credentials.get('input_price', 0))),
        unit=Decimal(str(credentials.get('unit', 0))),
        currency=credentials.get('currency', "USD"),
    )

    return AIModelEntity(
        model=model,
        label=I18nObject(en_US=model),
        model_type=ModelType.TEXT_EMBEDDING,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties=model_properties,
        parameter_rules=[],
        pricing=pricing,
    )
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage

View File

@ -76,21 +76,60 @@ model_credential_schema:
en_US: Enter your Endpoint ID
zh_Hans: 输入您的 Endpoint ID
- variable: base_model_name
show_on:
- variable: __model_type
value: llm
label:
en_US: Base Model
zh_Hans: 基础模型
type: select
required: true
options:
- label:
en_US: Doubao-pro-4k
value: Doubao-pro-4k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-lite-4k
value: Doubao-lite-4k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-pro-32k
value: Doubao-pro-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-lite-32k
value: Doubao-lite-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-pro-128k
value: Doubao-pro-128k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-lite-128k
value: Doubao-lite-128k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Skylark2-pro-4k
value: Skylark2-pro-4k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-embedding
value: Doubao-embedding
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: Custom
zh_Hans: 自定义
@ -122,8 +161,6 @@ model_credential_schema:
- variable: context_size
required: true
show_on:
- variable: __model_type
value: llm
- variable: base_model_name
value: Custom
label:

View File

@ -21,6 +21,7 @@ def test_validate_credentials():
'volc_access_key_id': 'INVALID',
'volc_secret_access_key': 'INVALID',
'endpoint_id': 'INVALID',
'base_model_name': 'Doubao-embedding',
}
)
@ -32,6 +33,7 @@ def test_validate_credentials():
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
},
)
@ -47,6 +49,7 @@ def test_invoke_model():
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
},
texts=[
"hello",
@ -71,6 +74,7 @@ def test_get_num_tokens():
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
},
texts=[
"hello",