From 6e9066ebf4edd278c8a4c080f0dbebc9630d5a41 Mon Sep 17 00:00:00 2001 From: sino Date: Thu, 16 May 2024 11:41:24 +0800 Subject: [PATCH] feat: support doubao llm and embedding models (#4431) --- .../volcengine_maas/llm/models.py | 62 ++++++++++++++++++- .../volcengine_maas/text_embedding/models.py | 9 +++ .../text_embedding/text_embedding.py | 42 ++++++++++++- .../volcengine_maas/volcengine_maas.yaml | 47 ++++++++++++-- .../volcengine_maas/test_embedding.py | 4 ++ 5 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index d022f0069b..2e8ff314fc 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,4 +1,64 @@ ModelConfigs = { + 'Doubao-pro-4k': { + 'req_params': { + 'max_prompt_tokens': 4096, + 'max_new_tokens': 4096, + }, + 'model_properties': { + 'context_size': 4096, + 'mode': 'chat', + } + }, + 'Doubao-lite-4k': { + 'req_params': { + 'max_prompt_tokens': 4096, + 'max_new_tokens': 4096, + }, + 'model_properties': { + 'context_size': 4096, + 'mode': 'chat', + } + }, + 'Doubao-pro-32k': { + 'req_params': { + 'max_prompt_tokens': 32768, + 'max_new_tokens': 32768, + }, + 'model_properties': { + 'context_size': 32768, + 'mode': 'chat', + } + }, + 'Doubao-lite-32k': { + 'req_params': { + 'max_prompt_tokens': 32768, + 'max_new_tokens': 32768, + }, + 'model_properties': { + 'context_size': 32768, + 'mode': 'chat', + } + }, + 'Doubao-pro-128k': { + 'req_params': { + 'max_prompt_tokens': 131072, + 'max_new_tokens': 131072, + }, + 'model_properties': { + 'context_size': 131072, + 'mode': 'chat', + } + }, + 'Doubao-lite-128k': { + 'req_params': { + 'max_prompt_tokens': 131072, + 'max_new_tokens': 131072, + }, + 
'model_properties': { + 'context_size': 131072, + 'mode': 'chat', + } + }, 'Skylark2-pro-4k': { 'req_params': { 'max_prompt_tokens': 4096, @@ -8,5 +68,5 @@ ModelConfigs = { 'context_size': 4096, 'mode': 'chat', } - } + }, } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py new file mode 100644 index 0000000000..569f89e975 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -0,0 +1,9 @@ +ModelConfigs = { + 'Doubao-embedding': { + 'req_params': {}, + 'model_properties': { + 'context_size': 4096, + 'max_chunks': 1, + } + }, +} diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index d63399aec2..10b01c0d0d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -1,7 +1,16 @@ import time +from decimal import Decimal from typing import Optional -from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -21,6 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import ( RateLimitErrors, ServerUnavailableErrors, ) +from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException @@ -45,7 +55,7 
@@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp['total_tokens']) + model=model, credentials=credentials, tokens=resp['usage']['total_tokens']) result = TextEmbeddingResult( model=model, @@ -101,6 +111,34 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): InvokeBadRequestError: BadRequestErrors.values(), } + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + model_properties = ModelConfigs.get( + credentials['base_model_name'], {}).get('model_properties', {}).copy() + if credentials.get('context_size'): + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( + credentials.get('context_size', 4096)) + if credentials.get('max_chunks'): + model_properties[ModelPropertyKey.MAX_CHUNKS] = int( + credentials.get('max_chunks', 4096)) + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties=model_properties, + parameter_rules=[], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ) + ) + + return entity + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml index 4f299ecae0..d7bcbd43f8 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -76,21 +76,60 @@ model_credential_schema: en_US: Enter your Endpoint ID 
zh_Hans: 输入您的 Endpoint ID - variable: base_model_name - show_on: - - variable: __model_type - value: llm label: en_US: Base Model zh_Hans: 基础模型 type: select required: true options: + - label: + en_US: Doubao-pro-4k + value: Doubao-pro-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-4k + value: Doubao-lite-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-32k + value: Doubao-pro-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-32k + value: Doubao-lite-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-128k + value: Doubao-pro-128k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-128k + value: Doubao-lite-128k + show_on: + - variable: __model_type + value: llm - label: en_US: Skylark2-pro-4k value: Skylark2-pro-4k show_on: - variable: __model_type value: llm + - label: + en_US: Doubao-embedding + value: Doubao-embedding + show_on: + - variable: __model_type + value: text-embedding - label: en_US: Custom zh_Hans: 自定义 @@ -122,8 +161,6 @@ model_credential_schema: - variable: context_size required: true show_on: - - variable: __model_type - value: llm - variable: base_model_name value: Custom label: diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py index 61e9f704af..3b399d604e 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -21,6 +21,7 @@ def test_validate_credentials(): 'volc_access_key_id': 'INVALID', 'volc_secret_access_key': 'INVALID', 'endpoint_id': 'INVALID', + 'base_model_name': 'Doubao-embedding', } ) @@ -32,6 +33,7 @@ def test_validate_credentials(): 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), 'volc_secret_access_key': 
os.environ.get('VOLC_SECRET_KEY'), 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + 'base_model_name': 'Doubao-embedding', }, ) @@ -47,6 +49,7 @@ def test_invoke_model(): 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + 'base_model_name': 'Doubao-embedding', }, texts=[ "hello", @@ -71,6 +74,7 @@ def test_get_num_tokens(): 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + 'base_model_name': 'Doubao-embedding', }, texts=[ "hello",