feat: support doubao llm and embedding models (#4431)

sino 2024-05-16 11:41:24 +08:00 committed by GitHub
parent dd94931116
commit 6e9066ebf4
5 changed files with 156 additions and 8 deletions


@@ -1,4 +1,64 @@
ModelConfigs = {
'Doubao-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
}
},
'Doubao-lite-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
}
},
'Doubao-pro-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
}
},
'Doubao-lite-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
}
},
'Doubao-pro-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
}
},
'Doubao-lite-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
}
},
'Skylark2-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
@@ -8,5 +68,5 @@ ModelConfigs = {
'context_size': 4096,
'mode': 'chat',
}
},
}
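For illustration only (not part of the diff): the entries above are plain dicts keyed by base model name, so provider code can look up a model's defaults with ordinary dict access. The helper name below is hypothetical; only the ModelConfigs structure comes from the change above.

# Hypothetical helper mirroring how the provider code reads ModelConfigs entries.
def resolve_model_properties(base_model_name: str) -> dict:
    # Unknown or custom base models fall back to an empty mapping.
    return ModelConfigs.get(base_model_name, {}).get('model_properties', {}).copy()

props = resolve_model_properties('Doubao-pro-32k')
# props == {'context_size': 32768, 'mode': 'chat'}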


@@ -0,0 +1,9 @@
ModelConfigs = {
'Doubao-embedding': {
'req_params': {},
'model_properties': {
'context_size': 4096,
'max_chunks': 1,
}
},
}


@@ -1,7 +1,16 @@
import time
from decimal import Decimal
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
PriceConfig,
PriceType,
)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@@ -21,6 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
@@ -45,7 +55,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=resp['usage']['total_tokens'])
result = TextEmbeddingResult(
model=model,
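The usage fix above reads the token count from resp['usage']['total_tokens'] instead of resp['total_tokens']. A minimal sketch of the response shape this implies is below; every field other than usage.total_tokens is illustrative, not taken from the SDK.

# Assumed shape of the MaaS embeddings response implied by the change above;
# only resp['usage']['total_tokens'] is actually read before _calc_response_usage.
resp = {
    'data': [{'embedding': [0.1, 0.2, 0.3]}],  # illustrative payload
    'usage': {'total_tokens': 7},               # the field the code now reads
}
tokens = resp['usage']['total_tokens']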
@@ -101,6 +111,34 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
InvokeBadRequestError: BadRequestErrors.values(),
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
model_properties = ModelConfigs.get(
credentials['base_model_name'], {}).get('model_properties', {}).copy()
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
if credentials.get('max_chunks'):
model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
credentials.get('max_chunks', 4096))
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties=model_properties,
parameter_rules=[],
pricing=PriceConfig(
input=Decimal(credentials.get('input_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
)
return entity
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
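A minimal usage sketch for the new get_customizable_model_schema above (not part of the diff): base_model_name selects defaults from ModelConfigs, while context_size, max_chunks, input_price, unit, and currency are optional overrides read from the credentials. The no-argument construction of the model class and the specific price values are assumptions.

# Illustrative only: build a customizable schema for a Doubao-embedding endpoint.
embedding_model = VolcengineMaaSTextEmbeddingModel()
entity = embedding_model.get_customizable_model_schema(
    model='doubao-embedding',
    credentials={
        'base_model_name': 'Doubao-embedding',  # defaults: context_size=4096, max_chunks=1
        'context_size': '4096',                 # optional override, parsed with int()
        'input_price': '0.0005',                # hypothetical price fed into PriceConfig
        'unit': '0.001',
        'currency': 'USD',
    },
)
# entity.fetch_from is FetchFrom.CUSTOMIZABLE_MODEL; entity.model_type is ModelType.TEXT_EMBEDDING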


@@ -76,21 +76,60 @@ model_credential_schema:
en_US: Enter your Endpoint ID
zh_Hans: 输入您的 Endpoint ID
- variable: base_model_name
show_on:
- variable: __model_type
value: llm
label:
en_US: Base Model
zh_Hans: 基础模型
type: select
required: true
options:
- label:
en_US: Doubao-pro-4k
value: Doubao-pro-4k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-lite-4k
value: Doubao-lite-4k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-pro-32k
value: Doubao-pro-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-lite-32k
value: Doubao-lite-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-pro-128k
value: Doubao-pro-128k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-lite-128k
value: Doubao-lite-128k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Skylark2-pro-4k
value: Skylark2-pro-4k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-embedding
value: Doubao-embedding
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: Custom
zh_Hans: 自定义
@@ -122,8 +161,6 @@ model_credential_schema:
- variable: context_size
required: true
show_on:
- variable: __model_type
value: llm
- variable: base_model_name
value: Custom
label:


@@ -21,6 +21,7 @@ def test_validate_credentials():
'volc_access_key_id': 'INVALID',
'volc_secret_access_key': 'INVALID',
'endpoint_id': 'INVALID',
'base_model_name': 'Doubao-embedding',
}
)
@@ -32,6 +33,7 @@ def test_validate_credentials():
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
},
)
@@ -47,6 +49,7 @@ def test_invoke_model():
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
},
texts=[
"hello",
@@ -71,6 +74,7 @@ def test_get_num_tokens():
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
},
texts=[
"hello",