mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 22:35:54 +08:00
feat: support doubao llm and embeding models (#4431)
This commit is contained in:
parent
dd94931116
commit
6e9066ebf4
@ -1,4 +1,64 @@
|
|||||||
ModelConfigs = {
|
ModelConfigs = {
|
||||||
|
'Doubao-pro-4k': {
|
||||||
|
'req_params': {
|
||||||
|
'max_prompt_tokens': 4096,
|
||||||
|
'max_new_tokens': 4096,
|
||||||
|
},
|
||||||
|
'model_properties': {
|
||||||
|
'context_size': 4096,
|
||||||
|
'mode': 'chat',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'Doubao-lite-4k': {
|
||||||
|
'req_params': {
|
||||||
|
'max_prompt_tokens': 4096,
|
||||||
|
'max_new_tokens': 4096,
|
||||||
|
},
|
||||||
|
'model_properties': {
|
||||||
|
'context_size': 4096,
|
||||||
|
'mode': 'chat',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'Doubao-pro-32k': {
|
||||||
|
'req_params': {
|
||||||
|
'max_prompt_tokens': 32768,
|
||||||
|
'max_new_tokens': 32768,
|
||||||
|
},
|
||||||
|
'model_properties': {
|
||||||
|
'context_size': 32768,
|
||||||
|
'mode': 'chat',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'Doubao-lite-32k': {
|
||||||
|
'req_params': {
|
||||||
|
'max_prompt_tokens': 32768,
|
||||||
|
'max_new_tokens': 32768,
|
||||||
|
},
|
||||||
|
'model_properties': {
|
||||||
|
'context_size': 32768,
|
||||||
|
'mode': 'chat',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'Doubao-pro-128k': {
|
||||||
|
'req_params': {
|
||||||
|
'max_prompt_tokens': 131072,
|
||||||
|
'max_new_tokens': 131072,
|
||||||
|
},
|
||||||
|
'model_properties': {
|
||||||
|
'context_size': 131072,
|
||||||
|
'mode': 'chat',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'Doubao-lite-128k': {
|
||||||
|
'req_params': {
|
||||||
|
'max_prompt_tokens': 131072,
|
||||||
|
'max_new_tokens': 131072,
|
||||||
|
},
|
||||||
|
'model_properties': {
|
||||||
|
'context_size': 131072,
|
||||||
|
'mode': 'chat',
|
||||||
|
}
|
||||||
|
},
|
||||||
'Skylark2-pro-4k': {
|
'Skylark2-pro-4k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
'max_prompt_tokens': 4096,
|
'max_prompt_tokens': 4096,
|
||||||
@ -8,5 +68,5 @@ ModelConfigs = {
|
|||||||
'context_size': 4096,
|
'context_size': 4096,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,9 @@
|
|||||||
|
ModelConfigs = {
|
||||||
|
'Doubao-embedding': {
|
||||||
|
'req_params': {},
|
||||||
|
'model_properties': {
|
||||||
|
'context_size': 4096,
|
||||||
|
'max_chunks': 1,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
@ -1,7 +1,16 @@
|
|||||||
import time
|
import time
|
||||||
|
from decimal import Decimal
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
FetchFrom,
|
||||||
|
ModelPropertyKey,
|
||||||
|
ModelType,
|
||||||
|
PriceConfig,
|
||||||
|
PriceType,
|
||||||
|
)
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
InvokeAuthorizationError,
|
||||||
@ -21,6 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
|
|||||||
RateLimitErrors,
|
RateLimitErrors,
|
||||||
ServerUnavailableErrors,
|
ServerUnavailableErrors,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
|
||||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
||||||
|
|
||||||
|
|
||||||
@ -45,7 +55,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
||||||
|
|
||||||
usage = self._calc_response_usage(
|
usage = self._calc_response_usage(
|
||||||
model=model, credentials=credentials, tokens=resp['total_tokens'])
|
model=model, credentials=credentials, tokens=resp['usage']['total_tokens'])
|
||||||
|
|
||||||
result = TextEmbeddingResult(
|
result = TextEmbeddingResult(
|
||||||
model=model,
|
model=model,
|
||||||
@ -101,6 +111,34 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
InvokeBadRequestError: BadRequestErrors.values(),
|
InvokeBadRequestError: BadRequestErrors.values(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
"""
|
||||||
|
generate custom model entities from credentials
|
||||||
|
"""
|
||||||
|
model_properties = ModelConfigs.get(
|
||||||
|
credentials['base_model_name'], {}).get('model_properties', {}).copy()
|
||||||
|
if credentials.get('context_size'):
|
||||||
|
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
||||||
|
credentials.get('context_size', 4096))
|
||||||
|
if credentials.get('max_chunks'):
|
||||||
|
model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
|
||||||
|
credentials.get('max_chunks', 4096))
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties=model_properties,
|
||||||
|
parameter_rules=[],
|
||||||
|
pricing=PriceConfig(
|
||||||
|
input=Decimal(credentials.get('input_price', 0)),
|
||||||
|
unit=Decimal(credentials.get('unit', 0)),
|
||||||
|
currency=credentials.get('currency', "USD")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||||
"""
|
"""
|
||||||
Calculate response usage
|
Calculate response usage
|
||||||
|
@ -76,21 +76,60 @@ model_credential_schema:
|
|||||||
en_US: Enter your Endpoint ID
|
en_US: Enter your Endpoint ID
|
||||||
zh_Hans: 输入您的 Endpoint ID
|
zh_Hans: 输入您的 Endpoint ID
|
||||||
- variable: base_model_name
|
- variable: base_model_name
|
||||||
show_on:
|
|
||||||
- variable: __model_type
|
|
||||||
value: llm
|
|
||||||
label:
|
label:
|
||||||
en_US: Base Model
|
en_US: Base Model
|
||||||
zh_Hans: 基础模型
|
zh_Hans: 基础模型
|
||||||
type: select
|
type: select
|
||||||
required: true
|
required: true
|
||||||
options:
|
options:
|
||||||
|
- label:
|
||||||
|
en_US: Doubao-pro-4k
|
||||||
|
value: Doubao-pro-4k
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
- label:
|
||||||
|
en_US: Doubao-lite-4k
|
||||||
|
value: Doubao-lite-4k
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
- label:
|
||||||
|
en_US: Doubao-pro-32k
|
||||||
|
value: Doubao-pro-32k
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
- label:
|
||||||
|
en_US: Doubao-lite-32k
|
||||||
|
value: Doubao-lite-32k
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
- label:
|
||||||
|
en_US: Doubao-pro-128k
|
||||||
|
value: Doubao-pro-128k
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
- label:
|
||||||
|
en_US: Doubao-lite-128k
|
||||||
|
value: Doubao-lite-128k
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
- label:
|
- label:
|
||||||
en_US: Skylark2-pro-4k
|
en_US: Skylark2-pro-4k
|
||||||
value: Skylark2-pro-4k
|
value: Skylark2-pro-4k
|
||||||
show_on:
|
show_on:
|
||||||
- variable: __model_type
|
- variable: __model_type
|
||||||
value: llm
|
value: llm
|
||||||
|
- label:
|
||||||
|
en_US: Doubao-embedding
|
||||||
|
value: Doubao-embedding
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: text-embedding
|
||||||
- label:
|
- label:
|
||||||
en_US: Custom
|
en_US: Custom
|
||||||
zh_Hans: 自定义
|
zh_Hans: 自定义
|
||||||
@ -122,8 +161,6 @@ model_credential_schema:
|
|||||||
- variable: context_size
|
- variable: context_size
|
||||||
required: true
|
required: true
|
||||||
show_on:
|
show_on:
|
||||||
- variable: __model_type
|
|
||||||
value: llm
|
|
||||||
- variable: base_model_name
|
- variable: base_model_name
|
||||||
value: Custom
|
value: Custom
|
||||||
label:
|
label:
|
||||||
|
@ -21,6 +21,7 @@ def test_validate_credentials():
|
|||||||
'volc_access_key_id': 'INVALID',
|
'volc_access_key_id': 'INVALID',
|
||||||
'volc_secret_access_key': 'INVALID',
|
'volc_secret_access_key': 'INVALID',
|
||||||
'endpoint_id': 'INVALID',
|
'endpoint_id': 'INVALID',
|
||||||
|
'base_model_name': 'Doubao-embedding',
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ def test_validate_credentials():
|
|||||||
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
||||||
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
||||||
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
||||||
|
'base_model_name': 'Doubao-embedding',
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,6 +49,7 @@ def test_invoke_model():
|
|||||||
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
||||||
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
||||||
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
||||||
|
'base_model_name': 'Doubao-embedding',
|
||||||
},
|
},
|
||||||
texts=[
|
texts=[
|
||||||
"hello",
|
"hello",
|
||||||
@ -71,6 +74,7 @@ def test_get_num_tokens():
|
|||||||
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
|
||||||
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
|
||||||
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
|
||||||
|
'base_model_name': 'Doubao-embedding',
|
||||||
},
|
},
|
||||||
texts=[
|
texts=[
|
||||||
"hello",
|
"hello",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user