diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index bda53e4394..1c27b2b4aa 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -26,5 +26,6 @@
- yi
- openllm
- localai
+- volcengine_maas
- openai_api_compatible
- deepseek
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..616e90916b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg
@@ -0,0 +1,23 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg
new file mode 100644
index 0000000000..24b92195bd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg
@@ -0,0 +1,39 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..e6454a89b7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg
@@ -0,0 +1,8 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py
new file mode 100644
index 0000000000..c7bf4fde8c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py
@@ -0,0 +1,108 @@
+import re
+from collections.abc import Callable, Generator
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ ImagePromptMessageContent,
+ PromptMessage,
+ PromptMessageContentType,
+ SystemPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
+
+
class MaaSClient(MaasService):
    """Thin wrapper over the Volcengine MaaS SDK service client.

    Adds endpoint-id bookkeeping, construction from a Dify credentials dict,
    and conversion from runtime PromptMessage objects to the MaaS wire format.
    """

    def __init__(self, host: str, region: str):
        # The endpoint id must be supplied via set_endpoint_id() before
        # chat()/embeddings() are called.
        self.endpoint_id = None
        super().__init__(host, region)

    def set_endpoint_id(self, endpoint_id: str):
        self.endpoint_id = endpoint_id

    @classmethod
    def from_credential(cls, credentials: dict) -> 'MaaSClient':
        """Build a ready-to-use client from the provider credentials dict.

        Expects keys: api_endpoint_host, volc_region, volc_access_key_id,
        volc_secret_access_key, endpoint_id. Raises KeyError if any is absent.
        """
        host = credentials['api_endpoint_host']
        region = credentials['volc_region']
        ak = credentials['volc_access_key_id']
        sk = credentials['volc_secret_access_key']
        endpoint_id = credentials['endpoint_id']

        client = cls(host, region)
        client.set_endpoint_id(endpoint_id)
        client.set_ak(ak)
        client.set_sk(sk)
        return client

    def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
        """Call the chat API on the configured endpoint.

        Returns the raw response dict for blocking calls, or a generator of
        response chunks when stream=True.
        """
        req = {
            'parameters': params,
            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
        }
        if not stream:
            return super().chat(
                self.endpoint_id,
                req,
            )
        return super().stream_chat(
            self.endpoint_id,
            req,
        )

    def embeddings(self, texts: list[str]) -> dict:
        """Call the embeddings API for a batch of input texts."""
        req = {
            'input': texts
        }
        return super().embeddings(self.endpoint_id, req)

    @staticmethod
    def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
        """Convert a runtime PromptMessage into the MaaS message dict format.

        Multi-part user content only supports images; plain-text parts in a
        content list are rejected. Raises ValueError for unsupported content
        or message types.
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": ChatRole.USER,
                                "content": message.content}
            else:
                content = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        raise ValueError(
                            'Content object type only support image_url')
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(
                            ImagePromptMessageContent, message_content)
                        # Strip a possible "data:image/...;base64," prefix so
                        # only the raw base64 payload is sent.
                        image_data = re.sub(
                            r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                        content.append({
                            'type': 'image_url',
                            'image_url': {
                                'url': '',
                                'image_bytes': image_data,
                                'detail': message_content.detail,
                            }
                        })

                message_dict = {'role': ChatRole.USER, 'content': content}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {'role': ChatRole.ASSISTANT,
                            'content': message.content}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {'role': ChatRole.SYSTEM,
                            'content': message.content}
        else:
            raise ValueError(f"Got unknown PromptMessage type {message}")

        return message_dict

    @staticmethod
    def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
        """Run fn, re-raising any MaasException as a provider-specific error
        class (see errors.wrap_error)."""
        try:
            resp = fn()
        except MaasException as e:
            raise wrap_error(e)

        return resp
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/errors.py
new file mode 100644
index 0000000000..63397a456e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/errors.py
@@ -0,0 +1,156 @@
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
# One exception subclass per MaaS service error-code name. The subclasses
# carry no extra behavior; they exist so the runtime can map each error code
# to a distinct type (grouped into the *Errors dicts below) and translate it
# into a unified InvokeError category.

class ClientSDKRequestError(MaasException):
    pass


class SignatureDoesNotMatch(MaasException):
    pass


class RequestTimeout(MaasException):
    pass


class ServiceConnectionTimeout(MaasException):
    pass


class MissingAuthenticationHeader(MaasException):
    pass


class AuthenticationHeaderIsInvalid(MaasException):
    pass


class InternalServiceError(MaasException):
    pass


class MissingParameter(MaasException):
    pass


class InvalidParameter(MaasException):
    pass


class AuthenticationExpire(MaasException):
    pass


class EndpointIsInvalid(MaasException):
    pass


class EndpointIsNotEnable(MaasException):
    pass


class ModelNotSupportStreamMode(MaasException):
    pass


# Req/Resp "TextExistRisk": the request or response text tripped the
# service-side content-safety check.
class ReqTextExistRisk(MaasException):
    pass


class RespTextExistRisk(MaasException):
    pass


class EndpointRateLimitExceeded(MaasException):
    pass


class ServiceConnectionRefused(MaasException):
    pass


class ServiceConnectionClosed(MaasException):
    pass


class UnauthorizedUserForEndpoint(MaasException):
    pass


class InvalidEndpointWithNoURL(MaasException):
    pass


# Rpm/Tpm: requests-per-minute / tokens-per-minute account quotas.
class EndpointAccountRpmRateLimitExceeded(MaasException):
    pass


class EndpointAccountTpmRateLimitExceeded(MaasException):
    pass


class ServiceResourceWaitQueueFull(MaasException):
    pass


class EndpointIsPending(MaasException):
    pass


class ServiceNotOpen(MaasException):
    pass
+
+
# Error-code name -> exception class, grouped by failure category. The group
# a code belongs to determines which unified InvokeError the model runtime
# maps it to (see the _invoke_error_mapping properties in llm.py and
# text_embedding.py).

# Authentication / authorization failures.
AuthErrors = {
    'SignatureDoesNotMatch': SignatureDoesNotMatch,
    'MissingAuthenticationHeader': MissingAuthenticationHeader,
    'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid,
    'AuthenticationExpire': AuthenticationExpire,
    'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint,
}

# Caller-side mistakes: bad parameters, bad endpoint, unsafe content.
BadRequestErrors = {
    'MissingParameter': MissingParameter,
    'InvalidParameter': InvalidParameter,
    'EndpointIsInvalid': EndpointIsInvalid,
    'EndpointIsNotEnable': EndpointIsNotEnable,
    'ModelNotSupportStreamMode': ModelNotSupportStreamMode,
    'ReqTextExistRisk': ReqTextExistRisk,
    'RespTextExistRisk': RespTextExistRisk,
    'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL,
    'ServiceNotOpen': ServiceNotOpen,
}

# Throttling: endpoint-level and account-level (rpm/tpm) quotas.
RateLimitErrors = {
    'EndpointRateLimitExceeded': EndpointRateLimitExceeded,
    'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded,
    'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded,
}

# Transient server-side conditions; usually worth retrying later.
ServerUnavailableErrors = {
    'InternalServiceError': InternalServiceError,
    'EndpointIsPending': EndpointIsPending,
    'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull,
}

# Transport-level failures between client and service.
ConnectionErrors = {
    'ClientSDKRequestError': ClientSDKRequestError,
    'RequestTimeout': RequestTimeout,
    'ServiceConnectionTimeout': ServiceConnectionTimeout,
    'ServiceConnectionRefused': ServiceConnectionRefused,
    'ServiceConnectionClosed': ServiceConnectionClosed,
}

# Flat lookup over all groups, used by wrap_error().
ErrorCodeMap = {
    **AuthErrors,
    **BadRequestErrors,
    **RateLimitErrors,
    **ServerUnavailableErrors,
    **ConnectionErrors,
}
+
+
def wrap_error(e: MaasException) -> Exception:
    """Map a raw MaasException to its code-specific subclass.

    Looks up ``e.code`` in ErrorCodeMap and rebuilds the exception as the
    matching subclass; unknown codes return the original exception unchanged.
    """
    # Single lookup instead of the get()-then-get() double lookup.
    exception_class = ErrorCodeMap.get(e.code)
    if exception_class is not None:
        return exception_class(e.code_n, e.code, e.message, e.req_id)
    return e
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
new file mode 100644
index 0000000000..7a36d019e2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
@@ -0,0 +1,284 @@
+import logging
+from collections.abc import Generator
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ PromptMessageTool,
+ UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ FetchFrom,
+ ModelPropertyKey,
+ ModelType,
+ ParameterRule,
+ ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+ AuthErrors,
+ BadRequestErrors,
+ ConnectionErrors,
+ RateLimitErrors,
+ ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+logger = logging.getLogger(__name__)
+
+
class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
    """LLM implementation backed by the Volcengine MaaS chat API."""

    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        """Invoke the model; all work is delegated to _generate."""
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate credentials by issuing a tiny "ping" chat completion.

        :raises CredentialsValidateFailedError: when the service rejects the call
        """
        client = MaaSClient.from_credential(credentials)
        try:
            client.chat(
                {
                    'max_new_tokens': 16,
                    'temperature': 0.7,
                    'top_p': 0.9,
                    'top_k': 15,
                },
                [UserPromptMessage(content='ping\nAnswer: ')],
            )
        except MaasException as e:
            raise CredentialsValidateFailedError(e.message)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
        """Approximate the prompt token count (tools are ignored)."""
        if len(prompt_messages) == 0:
            return 0
        return self._num_tokens_from_messages(prompt_messages)

    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
        """
        Calculate num tokens.

        Uses the GPT-2 tokenizer over the stringified message dicts as a
        rough approximation of the real server-side tokenizer.

        :param messages: messages
        """
        num_tokens = 0
        messages_dict = [
            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
        for message in messages_dict:
            for key, value in message.items():
                num_tokens += self._get_num_tokens_by_gpt2(str(key))
                num_tokens += self._get_num_tokens_by_gpt2(str(value))

        return num_tokens

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        """Build the request parameters and call the chat endpoint."""
        client = MaaSClient.from_credential(credentials)

        # Start from the base model's defaults, then layer on credential and
        # per-call overrides.
        req_params = ModelConfigs.get(
            credentials['base_model_name'], {}).get('req_params', {}).copy()
        # Credential values may be empty strings, so truthiness is the right
        # check here.
        if credentials.get('context_size'):
            req_params['max_prompt_tokens'] = credentials.get('context_size')
        if credentials.get('max_tokens'):
            req_params['max_new_tokens'] = credentials.get('max_tokens')
        # Model parameters are numeric; check against None so legitimate
        # zero values (e.g. temperature=0) are not silently dropped.
        if model_parameters.get('max_tokens') is not None:
            req_params['max_new_tokens'] = model_parameters.get('max_tokens')
        if model_parameters.get('temperature') is not None:
            req_params['temperature'] = model_parameters.get('temperature')
        if model_parameters.get('top_p') is not None:
            req_params['top_p'] = model_parameters.get('top_p')
        if model_parameters.get('top_k') is not None:
            req_params['top_k'] = model_parameters.get('top_k')
        if model_parameters.get('presence_penalty') is not None:
            req_params['presence_penalty'] = model_parameters.get(
                'presence_penalty')
        if model_parameters.get('frequency_penalty') is not None:
            req_params['frequency_penalty'] = model_parameters.get(
                'frequency_penalty')
        if stop:
            req_params['stop'] = stop

        resp = MaaSClient.wrap_exception(
            lambda: client.chat(req_params, prompt_messages, stream))
        if not stream:
            return self._handle_chat_response(model, credentials, prompt_messages, resp)
        return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)

    def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
        """Translate streamed response chunks into LLMResultChunk objects."""
        for index, r in enumerate(resp):
            choices = r['choices']
            if not choices:
                continue
            choice = choices[0]
            message = choice['message']
            usage = None
            # Usage is only present on the final chunk of a stream.
            if r.get('usage'):
                usage = self._calc_usage(model, credentials, r['usage'])
            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=AssistantPromptMessage(
                        content=message['content'] if message['content'] else '',
                        tool_calls=[]
                    ),
                    usage=usage,
                    finish_reason=choice.get('finish_reason'),
                ),
            )

    def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
        """Translate a blocking chat response dict into an LLMResult.

        NOTE(review): returns None when the response has no choices — callers
        should tolerate that; confirm whether raising would be preferable.
        """
        choices = resp['choices']
        if not choices:
            return
        choice = choices[0]
        message = choice['message']

        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=message['content'] if message['content'] else '',
                tool_calls=[],
            ),
            usage=self._calc_usage(model, credentials, resp['usage']),
        )

    def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
        """Convert the service usage dict into an LLMUsage via the base helper."""
        return self._calc_response_usage(model=model, credentials=credentials,
                                         prompt_tokens=usage['prompt_tokens'],
                                         completion_tokens=usage['completion_tokens']
                                         )

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        # max may be None when base_model_name is not in ModelConfigs and no
        # credential override exists; the rule then has no upper bound.
        max_tokens = ModelConfigs.get(
            credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
        if credentials.get('max_tokens'):
            max_tokens = int(credentials.get('max_tokens'))
        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                )
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='top_k',
                type=ParameterType.INT,
                min=1,
                default=1,
                label=I18nObject(
                    zh_Hans='Top K',
                    en_US='Top K'
                )
            ),
            # Use I18nObject for labels consistently (the original mixed raw
            # dicts and I18nObject instances).
            ParameterRule(
                name='presence_penalty',
                type=ParameterType.FLOAT,
                use_template='presence_penalty',
                label=I18nObject(
                    en_US='Presence Penalty',
                    zh_Hans='存在惩罚',
                ),
                min=-2.0,
                max=2.0,
            ),
            ParameterRule(
                name='frequency_penalty',
                type=ParameterType.FLOAT,
                use_template='frequency_penalty',
                label=I18nObject(
                    en_US='Frequency Penalty',
                    zh_Hans='频率惩罚',
                ),
                min=-2.0,
                max=2.0,
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=max_tokens,
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            ),
        ]

        model_properties = ModelConfigs.get(
            credentials['base_model_name'], {}).get('model_properties', {}).copy()
        if credentials.get('mode'):
            model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
        if credentials.get('context_size'):
            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
                credentials.get('context_size', 4096))
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            model_properties=model_properties,
            parameter_rules=rules
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        # Materialize the dict views as lists to match the declared type.
        return {
            InvokeConnectionError: list(ConnectionErrors.values()),
            InvokeServerUnavailableError: list(ServerUnavailableErrors.values()),
            InvokeRateLimitError: list(RateLimitErrors.values()),
            InvokeAuthorizationError: list(AuthErrors.values()),
            InvokeBadRequestError: list(BadRequestErrors.values()),
        }
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
new file mode 100644
index 0000000000..d022f0069b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
@@ -0,0 +1,12 @@
+ModelConfigs = {
+ 'Skylark2-pro-4k': {
+ 'req_params': {
+ 'max_prompt_tokens': 4096,
+ 'max_new_tokens': 4000,
+ },
+ 'model_properties': {
+ 'context_size': 4096,
+ 'mode': 'chat',
+ }
+ }
+}
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..d63399aec2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
@@ -0,0 +1,132 @@
+import time
+from typing import Optional
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+ AuthErrors,
+ BadRequestErrors,
+ ConnectionErrors,
+ RateLimitErrors,
+ ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for VolcengineMaaS text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        client = MaaSClient.from_credential(credentials)
        # wrap_exception re-raises service MaasExceptions as provider errors.
        resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))

        # Response is expected to carry 'total_tokens' and a 'data' list of
        # {'embedding': [...]} entries.
        usage = self._calc_response_usage(
            model=model, credentials=credentials, tokens=resp['total_tokens'])

        result = TextEmbeddingResult(
            model=model,
            embeddings=[v['embedding'] for v in resp['data']],
            usage=usage
        )

        return result

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        num_tokens = 0
        for text in texts:
            # use GPT2Tokenizer to get num tokens
            num_tokens += self._get_num_tokens_by_gpt2(text)
        return num_tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # A minimal real embedding call doubles as the credential check.
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except MaasException as e:
            raise CredentialsValidateFailedError(e.message)

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: ConnectionErrors.values(),
            InvokeServerUnavailableError: ServerUnavailableErrors.values(),
            InvokeRateLimitError: RateLimitErrors.values(),
            InvokeAuthorizationError: AuthErrors.values(),
            InvokeBadRequestError: BadRequestErrors.values(),
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # NOTE(review): relies on self.started_at being set by the base
            # class before invocation — confirm against TextEmbeddingModel.
            latency=time.perf_counter() - self.started_at
        )

        return usage
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py
new file mode 100644
index 0000000000..64f342f16e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py
@@ -0,0 +1,4 @@
+from .common import ChatRole
+from .maas import MaasException, MaasService
+
+__all__ = ['MaasService', 'ChatRole', 'MaasException']
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py
@@ -0,0 +1 @@
+
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py
new file mode 100644
index 0000000000..48110f16d7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py
@@ -0,0 +1,144 @@
+# coding : utf-8
+import datetime
+
+import pytz
+
+from .util import Util
+
+
class MetaData:
    """Accumulates the metadata fields needed while computing a V4 signature."""

    def __init__(self):
        # Every field starts empty and is populated by the signer.
        self.algorithm = self.credential_scope = self.signed_headers = ''
        self.date = self.region = self.service = ''

    # Plain setters, kept for SDK-style call sites.
    def set_date(self, date):
        self.date = date

    def set_service(self, service):
        self.service = service

    def set_region(self, region):
        self.region = region

    def set_algorithm(self, algorithm):
        self.algorithm = algorithm

    def set_credential_scope(self, credential_scope):
        self.credential_scope = credential_scope

    def set_signed_headers(self, signed_headers):
        self.signed_headers = signed_headers
+
+
class SignResult:
    """Holds the individual components of a computed request signature."""

    def __init__(self):
        # Initialize in a fixed order: __str__ iterates __dict__, whose
        # insertion order defines the rendered line order.
        for field in ('xdate', 'xCredential', 'xAlgorithm', 'xSignedHeaders',
                      'xSignedQueries', 'xSignature', 'xContextSha256',
                      'xSecurityToken', 'authorization'):
            setattr(self, field, '')

    def __str__(self):
        # One "name:value" line per attribute.
        return '\n'.join(f'{name}:{value}' for name, value in self.__dict__.items())
+
+
class Credentials:
    """Access-key/secret-key pair plus the service scope used for signing."""

    def __init__(self, ak, sk, service, region, session_token=''):
        self.ak, self.sk = ak, sk
        self.service, self.region = service, region
        self.session_token = session_token

    def set_ak(self, ak):
        """Replace the access key id."""
        self.ak = ak

    def set_sk(self, sk):
        """Replace the secret access key."""
        self.sk = sk

    def set_session_token(self, session_token):
        """Replace the optional STS session token."""
        self.session_token = session_token
+
+
class Signer:
    """Implements the HMAC-SHA256 V4-style request-signing scheme."""

    @staticmethod
    def sign(request, credentials):
        """Compute the signature and set it on request.headers in place.

        Mutates the request: normalizes path, defaults Content-Type for
        non-GET, stamps X-Date (and X-Security-Token when present), and
        finally writes the Authorization header.
        """
        if request.path == '':
            request.path = '/'
        if request.method != 'GET' and not ('Content-Type' in request.headers):
            request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'

        format_date = Signer.get_current_format_date()
        request.headers['X-Date'] = format_date
        if credentials.session_token != '':
            request.headers['X-Security-Token'] = credentials.session_token

        md = MetaData()
        md.set_algorithm('HMAC-SHA256')
        md.set_service(credentials.service)
        md.set_region(credentials.region)
        # Credential-scope date is the YYYYMMDD prefix of the full timestamp.
        md.set_date(format_date[:8])

        # NOTE: hashing the canonical request also fills md.signed_headers.
        hashed_canon_req = Signer.hashed_canonical_request_v4(request, md)
        md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request']))

        # String-to-sign = algorithm, timestamp, scope, canonical-request hash.
        signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req])
        signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
        sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
        request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials)
        return

    @staticmethod
    def hashed_canonical_request_v4(request, meta):
        """Build and hash the canonical request; records signed headers on meta.

        Side effects: sets the X-Content-Sha256 header on the request and
        meta.signed_headers.
        """
        body_hash = Util.sha256(request.body)
        request.headers['X-Content-Sha256'] = body_hash

        # Only Content-Type, Content-Md5, Host and X-* headers participate.
        signed_headers = dict()
        for key in request.headers:
            if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'):
                signed_headers[key.lower()] = request.headers[key]

        # Strip default ports (80/443) from the host value before signing.
        if 'host' in signed_headers:
            v = signed_headers['host']
            if v.find(':') != -1:
                split = v.split(':')
                port = split[1]
                if str(port) == '80' or str(port) == '443':
                    signed_headers['host'] = split[0]

        signed_str = ''
        for key in sorted(signed_headers.keys()):
            signed_str += key + ':' + signed_headers[key] + '\n'

        meta.set_signed_headers(';'.join(sorted(signed_headers.keys())))

        canonical_request = '\n'.join(
            [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str,
             meta.signed_headers, body_hash])

        return Util.sha256(canonical_request)

    @staticmethod
    def get_signing_secret_key_v4(sk, date, region, service):
        """Derive the signing key: HMAC chain sk -> date -> region -> service -> 'request'."""
        date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date)
        region = Util.hmac_sha256(date, region)
        service = Util.hmac_sha256(region, service)
        return Util.hmac_sha256(service, 'request')

    @staticmethod
    def build_auth_header_v4(signature, meta, credentials):
        """Assemble the Authorization header value from its signed parts."""
        credential = credentials.ak + '/' + meta.credential_scope
        return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature

    @staticmethod
    def get_current_format_date():
        """Return the current UTC time as a YYYYMMDDTHHMMSSZ timestamp."""
        return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ")
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py
new file mode 100644
index 0000000000..03734ec54f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py
@@ -0,0 +1,207 @@
+import json
+from collections import OrderedDict
+from urllib.parse import urlencode
+
+import requests
+
+from .auth import Signer
+
+VERSION = 'v1.0.137'
+
+
class Service:
    """Base HTTP service: signs each request and dispatches it over a shared
    requests session.

    api_info maps an api name to an ApiInfo route description; service_info
    carries host, credentials and timeouts.
    """

    def __init__(self, service_info, api_info):
        self.service_info = service_info
        self.api_info = api_info
        # One session reused across calls for connection pooling.
        self.session = requests.session()

    def set_ak(self, ak):
        self.service_info.credentials.set_ak(ak)

    def set_sk(self, sk):
        self.service_info.credentials.set_sk(sk)

    def set_session_token(self, session_token):
        self.service_info.credentials.set_session_token(session_token)

    def set_host(self, host):
        self.service_info.host = host

    def set_scheme(self, scheme):
        self.service_info.scheme = scheme

    def get(self, api, params, doseq=0):
        """Signed GET for a registered api; returns the response body text.

        Raises a plain Exception carrying the body on any non-200 status.
        """
        if not (api in self.api_info):
            raise Exception("no such api")
        api_info = self.api_info[api]

        r = self.prepare_request(api_info, params, doseq)

        Signer.sign(r, self.service_info.credentials)

        url = r.build(doseq)
        resp = self.session.get(url, headers=r.headers,
                                timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
        if resp.status_code == 200:
            return resp.text
        else:
            raise Exception(resp.text)

    def post(self, api, params, form):
        """Signed form-encoded POST; returns the response body text."""
        if not (api in self.api_info):
            raise Exception("no such api")
        api_info = self.api_info[api]
        r = self.prepare_request(api_info, params)
        r.headers['Content-Type'] = 'application/x-www-form-urlencoded'
        r.form = self.merge(api_info.form, form)
        # Body must match what is sent so the signature stays valid.
        r.body = urlencode(r.form, True)
        Signer.sign(r, self.service_info.credentials)

        url = r.build()

        resp = self.session.post(url, headers=r.headers, data=r.form,
                                 timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
        if resp.status_code == 200:
            return resp.text
        else:
            raise Exception(resp.text)

    def json(self, api, params, body):
        """Signed JSON POST; body is the pre-serialized JSON string.

        Returns the response re-serialized via json.dumps.
        """
        if not (api in self.api_info):
            raise Exception("no such api")
        api_info = self.api_info[api]
        r = self.prepare_request(api_info, params)
        r.headers['Content-Type'] = 'application/json'
        r.body = body

        Signer.sign(r, self.service_info.credentials)

        url = r.build()
        resp = self.session.post(url, headers=r.headers, data=r.body,
                                 timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
        if resp.status_code == 200:
            return json.dumps(resp.json())
        else:
            raise Exception(resp.text.encode("utf-8"))

    def put(self, url, file_path, headers):
        """Unsigned PUT upload of a file; returns (ok, body_bytes)."""
        with open(file_path, 'rb') as f:
            resp = self.session.put(url, headers=headers, data=f)
            if resp.status_code == 200:
                return True, resp.text.encode("utf-8")
            else:
                return False, resp.text.encode("utf-8")

    def put_data(self, url, data, headers):
        """Unsigned PUT upload of in-memory data; returns (ok, body_bytes)."""
        resp = self.session.put(url, headers=headers, data=data)
        if resp.status_code == 200:
            return True, resp.text.encode("utf-8")
        else:
            return False, resp.text.encode("utf-8")

    def prepare_request(self, api_info, params, doseq=0):
        """Build an unsigned Request for the route with merged headers/query.

        Scalar params are stringified; list params are comma-joined unless
        doseq requests sequence encoding.
        """
        for key in params:
            # NOTE(review): type()==... comparisons; isinstance would be the
            # idiomatic check but is left as-is to preserve exact semantics.
            if type(params[key]) == int or type(params[key]) == float or type(params[key]) == bool:
                params[key] = str(params[key])
            elif type(params[key]) == list:
                if not doseq:
                    params[key] = ','.join(params[key])

        connection_timeout = self.service_info.connection_timeout
        socket_timeout = self.service_info.socket_timeout

        r = Request()
        r.set_schema(self.service_info.scheme)
        r.set_method(api_info.method)
        r.set_connection_timeout(connection_timeout)
        r.set_socket_timeout(socket_timeout)

        headers = self.merge(api_info.header, self.service_info.header)
        headers['Host'] = self.service_info.host
        headers['User-Agent'] = 'volc-sdk-python/' + VERSION
        r.set_headers(headers)

        query = self.merge(api_info.query, params)
        r.set_query(query)

        r.set_host(self.service_info.host)
        r.set_path(api_info.path)

        return r

    @staticmethod
    def merge(param1, param2):
        """Merge two mappings into an OrderedDict; param2 wins on conflicts."""
        od = OrderedDict()
        for key in param1:
            od[key] = param1[key]

        for key in param2:
            od[key] = param2[key]

        return od
+
+
class Request:
    """Mutable description of one HTTP request, consumed by Signer and Service."""

    def __init__(self):
        # String-valued fields.
        self.schema = self.method = self.host = self.path = self.body = ''
        # Ordered mappings so header/query ordering is deterministic.
        self.headers = OrderedDict()
        self.query = OrderedDict()
        self.form = {}
        # Timeouts in seconds; 0 until configured.
        self.connection_timeout = self.socket_timeout = 0

    # Plain setters kept for SDK-style call sites.
    def set_schema(self, schema):
        self.schema = schema

    def set_method(self, method):
        self.method = method

    def set_host(self, host):
        self.host = host

    def set_path(self, path):
        self.path = path

    def set_headers(self, headers):
        self.headers = headers

    def set_query(self, query):
        self.query = query

    def set_body(self, body):
        self.body = body

    def set_connection_timeout(self, connection_timeout):
        self.connection_timeout = connection_timeout

    def set_socket_timeout(self, socket_timeout):
        self.socket_timeout = socket_timeout

    def build(self, doseq=0):
        """Assemble the full URL including the encoded query string."""
        return f'{self.schema}://{self.host}{self.path}?{urlencode(self.query, doseq)}'
+
+
class ServiceInfo:
    """Connection and auth settings bundle for a Service instance."""

    def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'):
        self.host, self.header, self.credentials = host, header, credentials
        self.connection_timeout, self.socket_timeout = connection_timeout, socket_timeout
        self.scheme = scheme
+
+
class ApiInfo:
    """Static description of one API route (method, path, default params)."""

    def __init__(self, method, path, query, form, header):
        self.method, self.path = method, path
        self.query, self.form, self.header = query, form, header

    def __str__(self):
        return f'method: {self.method}, path: {self.path}'
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py
new file mode 100644
index 0000000000..7eb5fdfa91
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py
@@ -0,0 +1,43 @@
+import hashlib
+import hmac
+from functools import reduce
+from urllib.parse import quote
+
+
class Util:
    """Encoding and hashing helpers used when signing Volcengine requests."""

    @staticmethod
    def norm_uri(path):
        """Percent-encode *path* for signing, keeping '/' literal and
        encoding spaces as %20 (not '+')."""
        return quote(path).replace('%2F', '/').replace('+', '%20')

    @staticmethod
    def norm_query(params):
        """Build the canonical query string: keys sorted, values RFC-3986
        encoded (safe chars ``-_.~``), list values expanded into repeated
        ``key=value`` pairs, joined with '&' and spaces forced to %20."""
        parts = []
        for key in sorted(params):
            encoded_key = quote(key, safe='-_.~')
            value = params[key]
            # isinstance (not type() ==) so list subclasses are expanded too.
            values = value if isinstance(value, list) else [value]
            for item in values:
                parts.append(encoded_key + '=' + quote(item, safe='-_.~'))
        # join avoids the quadratic string concatenation of the append-'&' loop.
        return '&'.join(parts).replace('+', '%20')

    @staticmethod
    def hmac_sha256(key, content):
        """Return the raw HMAC-SHA256 digest of *content* (str) under *key* (bytes)."""
        return hmac.new(key, content.encode('utf-8'), hashlib.sha256).digest()

    @staticmethod
    def sha256(content):
        """Return the hex SHA-256 of *content*; str input is UTF-8 encoded first."""
        if isinstance(content, str):
            content = content.encode('utf-8')
        return hashlib.sha256(content).hexdigest()

    @staticmethod
    def to_hex(content):
        """Render an iterable of byte values as a lowercase hex string.

        Unlike the previous reduce()-based version this also accepts empty
        input (returns '').
        """
        return ''.join(format(ch, '02x') for ch in content)
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py
new file mode 100644
index 0000000000..8b14d026d9
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py
@@ -0,0 +1,79 @@
+import json
+import random
+from datetime import datetime
+
+
class ChatRole:
    """String constants for the chat message roles accepted by the MaaS chat API."""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
+
+
+class _Dict(dict):
+ __setattr__ = dict.__setitem__
+ __getattr__ = dict.__getitem__
+
+ def __missing__(self, key):
+ return None
+
+
def dict_to_object(dict_obj):
    """Recursively convert parsed JSON data into attribute-accessible objects.

    Lists are converted element-wise; dicts become ``_Dict`` instances
    (attribute access, None for missing keys); scalars pass through
    unchanged. Nested containers are handled depth-first.
    """
    if isinstance(dict_obj, list):
        # Comprehension instead of the manual append loop.
        return [dict_to_object(item) for item in dict_obj]

    if isinstance(dict_obj, dict):
        inst = _Dict()
        for k, v in dict_obj.items():
            inst[k] = dict_to_object(v)
        return inst

    return dict_obj
+
+
def json_to_object(json_str, req_id=None):
    """Parse a JSON string into attribute-accessible objects.

    When the top-level value is a non-empty dict and *req_id* is given,
    the request id is stored under the ``req_id`` key for traceability.
    """
    result = dict_to_object(json.loads(json_str))
    if req_id and isinstance(result, dict) and result:
        result["req_id"] = req_id
    return result
+
+
def gen_req_id():
    """Generate a request id: a 14-char local timestamp (YYYYMMDDHHMMSS)
    followed by a random 64-bit value as 20 zero-padded uppercase hex chars."""
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    random_part = format(random.randint(0, 2 ** 64 - 1), "020X")
    return timestamp + random_part
+
+
class SSEDecoder:
    """Minimal Server-Sent-Events decoder over an iterable of byte chunks.

    Events are separated by a blank line; only ``data:`` field values are
    yielded (leading space after the colon is NOT stripped, matching the
    original behavior).
    """

    def __init__(self, source):
        self.source = source

    def _read(self):
        """Yield one raw event (bytes) at a time, re-assembled across chunks."""
        buffer = b''
        for chunk in self.source:
            for piece in chunk.splitlines(True):
                buffer += piece
                # A blank line (any newline convention) terminates an event.
                if buffer.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
                    yield buffer
                    buffer = b''
        if buffer:
            # Trailing partial event with no terminator.
            yield buffer

    def next(self):
        """Iterate over the values of non-empty ``data`` fields."""
        for event in self._read():
            for raw_line in event.splitlines():
                # Lines starting with ':' are SSE comments.
                if raw_line.startswith(b':'):
                    continue

                # partition gives (line, b'', b'') when there is no colon,
                # matching the original split-or-default logic.
                field, _, value = raw_line.partition(b':')

                if field == b'data' and len(value) > 0:
                    yield value
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py
new file mode 100644
index 0000000000..3cbe9d9f09
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py
@@ -0,0 +1,213 @@
+import copy
+import json
+from collections.abc import Iterator
+
+from .base.auth import Credentials, Signer
+from .base.service import ApiInfo, Service, ServiceInfo
+from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object
+
+
class MaasService(Service):
    """Client for the Volcengine MaaS v2 endpoint API (chat and embeddings).

    Authentication is either AK/SK request signing through the base
    ``Service`` credentials, or bearer-token auth via :meth:`set_apikey`.
    """

    def __init__(self, host, region, connection_timeout=60, socket_timeout=60):
        service_info = self.get_service_info(
            host, region, connection_timeout, socket_timeout
        )
        self._apikey = None
        api_info = self.get_api_info()
        super().__init__(service_info, api_info)

    def set_apikey(self, apikey):
        """Switch subsequent calls to bearer-token auth instead of AK/SK signing."""
        self._apikey = apikey

    @staticmethod
    def get_service_info(host, region, connection_timeout, socket_timeout):
        """Build the HTTPS ServiceInfo for the MaaS endpoint."""
        return ServiceInfo(
            host,
            {"Accept": "application/json"},
            # AK/SK are empty here; they are filled in later via the base
            # Service credential setters.
            Credentials("", "", "ml_maas", region),
            connection_timeout,
            socket_timeout,
            "https",
        )

    @staticmethod
    def get_api_info():
        """Map logical API names to their HTTP method and path templates."""
        return {
            "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}),
            "embeddings": ApiInfo(
                "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}
            ),
        }

    def chat(self, endpoint_id, req):
        """Synchronous (non-streaming) chat completion."""
        req["stream"] = False
        return self._request(endpoint_id, "chat", req)

    def stream_chat(self, endpoint_id, req):
        """Streaming chat completion.

        Returns a generator of parsed SSE events. Raises ``MaasException``
        on service-reported errors and a client SDK error on transport
        failures.
        """
        req_id = gen_req_id()
        self._validate("chat", req_id)
        apikey = self._apikey

        try:
            req["stream"] = True
            res = self._call(
                endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True
            )

            decoder = SSEDecoder(res)

            def iter_fn():
                for data in decoder.next():
                    if data == b"[DONE]":
                        return

                    # Removed a no-op "except Exception: raise" wrapper here.
                    res = json_to_object(
                        str(data, encoding="utf-8"), req_id=req_id)

                    if res.error is not None and res.error.code_n != 0:
                        raise MaasException(
                            res.error.code_n,
                            res.error.code,
                            res.error.message,
                            req_id,
                        )
                    yield res

            return iter_fn()
        except MaasException:
            raise
        except Exception as e:
            raise new_client_sdk_request_error(str(e)) from e

    def embeddings(self, endpoint_id, req):
        """Request embeddings from the given endpoint."""
        return self._request(endpoint_id, "embeddings", req)

    def _request(self, endpoint_id, api, req, params=None):
        """Issue a non-streaming JSON request and return the parsed response.

        ``params`` defaults to ``None`` (was a shared mutable ``{}`` default).
        """
        req_id = gen_req_id()

        self._validate(api, req_id)

        apikey = self._apikey

        try:
            res = self._call(endpoint_id, api, req_id,
                             params if params is not None else {},
                             json.dumps(req).encode("utf-8"), apikey)
            resp = dict_to_object(res.json())
            if resp and isinstance(resp, dict):
                resp["req_id"] = req_id
            return resp

        except MaasException:
            # Bare re-raise preserves the original traceback.
            raise
        except Exception as e:
            raise new_client_sdk_request_error(str(e), req_id) from e

    def _validate(self, api, req_id):
        """Ensure some credential (API key or AK/SK) exists and *api* is known."""
        credentials_exist = (
            self.service_info.credentials is not None and
            self.service_info.credentials.sk is not None and
            self.service_info.credentials.ak is not None
        )

        if not self._apikey and not credentials_exist:
            raise new_client_sdk_request_error("no valid credential", req_id)

        if api not in self.api_info:
            raise new_client_sdk_request_error("no such api", req_id)

    def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
        """Perform the HTTP POST for *api*, signing or bearer-authorizing it,
        and raise a typed error on any non-200 response."""
        api_info = copy.deepcopy(self.api_info[api])
        api_info.path = api_info.path.format(endpoint_id=endpoint_id)

        r = self.prepare_request(api_info, params)
        r.headers["x-tt-logid"] = req_id
        r.headers["Content-Type"] = "application/json"
        r.body = body

        if apikey is None:
            # AK/SK flow: sign the request in place.
            Signer.sign(r, self.service_info.credentials)
        else:
            r.headers["Authorization"] = "Bearer " + apikey

        url = r.build()
        res = self.session.post(
            url,
            headers=r.headers,
            data=r.body,
            timeout=(
                self.service_info.connection_timeout,
                self.service_info.socket_timeout,
            ),
            stream=stream,
        )

        if res.status_code != 200:
            raw = res.text.encode()
            res.close()
            try:
                resp = json_to_object(
                    str(raw, encoding="utf-8"), req_id=req_id)
            except Exception:
                # Body was not JSON: surface it verbatim as a client error.
                raise new_client_sdk_request_error(raw, req_id)

            if resp.error:
                raise MaasException(
                    resp.error.code_n, resp.error.code, resp.error.message, req_id
                )
            else:
                raise new_client_sdk_request_error(resp, req_id)

        return res
+
+
class MaasException(Exception):
    """Error reported by the MaaS service or raised by the SDK client.

    Carries the numeric code, symbolic code, human-readable message and the
    request id the failure belongs to.
    """

    def __init__(self, code_n, code, message, req_id):
        # Intentionally no super().__init__(...) call so that Exception.args
        # keeps all four constructor arguments, as before.
        self.code_n = code_n
        self.code = code
        self.message = message
        self.req_id = req_id

    def __str__(self):
        return (
            "Detailed exception information is listed below.\n"
            f"req_id: {self.req_id}\n"
            f"code_n: {self.code_n}\n"
            f"code: {self.code}\n"
            f"message: {self.message}"
        )
+
+
def new_client_sdk_request_error(raw, req_id=""):
    """Build the standard client-side SDK error (fixed code 1709701) wrapping *raw*."""
    return MaasException(1709701, "ClientSDKRequestError", f"MaaS SDK request error: {raw}", req_id)
+
+
class BinaryResponseContent:
    """Wraps a streaming binary response (an iterable of byte chunks),
    providing file download and raw-byte iteration."""

    def __init__(self, response, request_id) -> None:
        # response: iterable yielding bytes chunks; request_id: id used when
        # reporting service errors found in the stream.
        self.response = response
        self.request_id = request_id

    def stream_to_file(
        self,
        file: str
    ) -> None:
        """Write all response chunks to *file*.

        If a chunk contains the literal '"error":' marker, the remainder of
        the stream is collected as a JSON error payload and raised as a
        MaasException instead of being written.
        """
        is_first = True
        error_bytes = b''
        with open(file, mode="wb") as f:
            for data in self.response:
                # NOTE(review): is_first is never set to False, so the
                # '"error":' probe runs on EVERY chunk, not only the first —
                # confirm whether mid-stream detection is intended.
                if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)):
                    error_bytes += data
                else:
                    f.write(data)

        if len(error_bytes) > 0:
            # Parse the accumulated error payload and surface it as a typed error.
            resp = json_to_object(
                str(error_bytes, encoding="utf-8"), req_id=self.request_id)
            raise MaasException(
                resp.error.code_n, resp.error.code, resp.error.message, self.request_id
            )

    def iter_bytes(self) -> Iterator[bytes]:
        """Yield the raw response chunks unchanged."""
        yield from self.response
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py
new file mode 100644
index 0000000000..10f9be2d08
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py
@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
class VolcengineMaaSProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """Intentional no-op.

        This provider is configured per model (``customizable-model`` in
        volcengine_maas.yaml), so there are no provider-level credentials to
        check here — presumably each model validates its own credentials.
        """
        pass
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml
new file mode 100644
index 0000000000..4f299ecae0
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml
@@ -0,0 +1,151 @@
+provider: volcengine_maas
+label:
+ en_US: Volcengine
+description:
+ en_US: Volcengine MaaS models.
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+ zh_Hans: icon_l_zh.svg
+background: "#F9FAFB"
+help:
+ title:
+ en_US: Get your Access Key and Secret Access Key from Volcengine Console
+ url:
+ en_US: https://console.volcengine.com/iam/keymanage/
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your Model Name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: volc_access_key_id
+ required: true
+ label:
+ en_US: Access Key
+ zh_Hans: Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Access Key
+ zh_Hans: 输入您的 Access Key
+ - variable: volc_secret_access_key
+ required: true
+ label:
+ en_US: Secret Access Key
+ zh_Hans: Secret Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Secret Access Key
+ zh_Hans: 输入您的 Secret Access Key
+ - variable: volc_region
+ required: true
+ label:
+ en_US: Volcengine Region
+ zh_Hans: 火山引擎地区
+ type: text-input
+ default: cn-beijing
+ placeholder:
+ en_US: Enter Volcengine Region
+      zh_Hans: 输入火山引擎地区
+ - variable: api_endpoint_host
+ required: true
+ label:
+ en_US: API Endpoint Host
+ zh_Hans: API Endpoint Host
+ type: text-input
+ default: maas-api.ml-platform-cn-beijing.volces.com
+ placeholder:
+ en_US: Enter your API Endpoint Host
+ zh_Hans: 输入 API Endpoint Host
+ - variable: endpoint_id
+ required: true
+ label:
+ en_US: Endpoint ID
+ zh_Hans: Endpoint ID
+ type: text-input
+ placeholder:
+ en_US: Enter your Endpoint ID
+ zh_Hans: 输入您的 Endpoint ID
+ - variable: base_model_name
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ en_US: Base Model
+ zh_Hans: 基础模型
+ type: select
+ required: true
+ options:
+ - label:
+ en_US: Skylark2-pro-4k
+ value: Skylark2-pro-4k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Custom
+ zh_Hans: 自定义
+ value: Custom
+ - variable: mode
+ required: true
+ show_on:
+ - variable: __model_type
+ value: llm
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 模型类型
+ en_US: Completion Mode
+ type: select
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select Completion Mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ required: true
+ show_on:
+ - variable: __model_type
+ value: llm
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model Context Size
+ type: text-input
+ default: '4096'
+ placeholder:
+ zh_Hans: 输入您的模型上下文长度
+ en_US: Enter your Model Context Size
+ - variable: max_tokens
+ required: true
+ show_on:
+ - variable: __model_type
+ value: llm
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper Bound for Max Tokens
+ default: '4096'
+ type: text-input
+ placeholder:
+ zh_Hans: 输入您的模型最大 token 上限
+ en_US: Enter your model Upper Bound for Max Tokens
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 9cd04b4764..f29e5ef4d6 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -73,4 +73,10 @@ MOCK_SWITCH=false
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTION_API_KEY=
\ No newline at end of file
+CODE_EXECUTION_API_KEY=
+
+# Volcengine MaaS Credentials
+VOLC_API_KEY=
+VOLC_SECRET_KEY=
+VOLC_MODEL_ENDPOINT_ID=
+VOLC_EMBEDDING_ENDPOINT_ID=
\ No newline at end of file
diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py b/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py
new file mode 100644
index 0000000000..61e9f704af
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py
@@ -0,0 +1,81 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import (
+ VolcengineMaaSTextEmbeddingModel,
+)
+
+
def test_validate_credentials():
    """Invalid AK/SK/endpoint must fail validation; env-provided credentials must pass."""
    model = VolcengineMaaSTextEmbeddingModel()

    # Deliberately bogus credentials -> expect CredentialsValidateFailedError.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='NOT IMPORTANT',
            credentials={
                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
                'volc_region': 'cn-beijing',
                'volc_access_key_id': 'INVALID',
                'volc_secret_access_key': 'INVALID',
                'endpoint_id': 'INVALID',
            }
        )

    # Real credentials come from the environment (see integration_tests/.env.example).
    model.validate_credentials(
        model='NOT IMPORTANT',
        credentials={
            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
            'volc_region': 'cn-beijing',
            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
        },
    )
+
+
def test_invoke_model():
    """Invoke the embedding endpoint with two texts and sanity-check the result."""
    model = VolcengineMaaSTextEmbeddingModel()

    result = model.invoke(
        model='NOT IMPORTANT',
        credentials={
            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
            'volc_region': 'cn-beijing',
            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
    )

    # One embedding per input text; token usage must be reported.
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens > 0
+
+
def test_get_num_tokens():
    """Token counting for the embedding model."""
    model = VolcengineMaaSTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='NOT IMPORTANT',
        credentials={
            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
            'volc_region': 'cn-beijing',
            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
        },
        texts=[
            "hello",
            "world"
        ]
    )

    # NOTE(review): expects exactly 2 — presumably one token per single-word
    # text; confirm against the model's token-counting implementation.
    assert num_tokens == 2
diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py
new file mode 100644
index 0000000000..63835d0263
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py
@@ -0,0 +1,131 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel
+
+
def test_validate_credentials_for_chat_model():
    """Invalid AK/SK/endpoint must fail LLM validation; env credentials must pass."""
    model = VolcengineMaaSLargeLanguageModel()

    # Deliberately bogus credentials -> expect CredentialsValidateFailedError.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='NOT IMPORTANT',
            credentials={
                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
                'volc_region': 'cn-beijing',
                'volc_access_key_id': 'INVALID',
                'volc_secret_access_key': 'INVALID',
                'endpoint_id': 'INVALID',
            }
        )

    # Real credentials come from the environment (see integration_tests/.env.example).
    model.validate_credentials(
        model='NOT IMPORTANT',
        credentials={
            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
            'volc_region': 'cn-beijing',
            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
        }
    )
+
+
def test_invoke_model():
    """Non-streaming chat invocation against the Skylark2-pro-4k base model."""
    model = VolcengineMaaSLargeLanguageModel()

    response = model.invoke(
        model='NOT IMPORTANT',
        credentials={
            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
            'volc_region': 'cn-beijing',
            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
            'base_model_name': 'Skylark2-pro-4k',
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    # Non-streamed call returns a complete LLMResult with usage accounting.
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
+
+
def test_invoke_stream_model():
    """Streaming chat invocation: every chunk must be a well-formed delta."""
    model = VolcengineMaaSLargeLanguageModel()

    response = model.invoke(
        model='NOT IMPORTANT',
        credentials={
            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
            'volc_region': 'cn-beijing',
            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
            'base_model_name': 'Skylark2-pro-4k',
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    # Streamed call returns a generator of chunks; non-final chunks must
    # carry non-empty content.
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(
            chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
def test_get_num_tokens():
    """Token counting for a single short prompt on the LLM."""
    model = VolcengineMaaSLargeLanguageModel()

    response = model.get_num_tokens(
        model='NOT IMPORTANT',
        credentials={
            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
            'volc_region': 'cn-beijing',
            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
            'base_model_name': 'Skylark2-pro-4k',
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        tools=[]
    )

    # NOTE(review): expects exactly 6 tokens for 'Hello World!' — confirm
    # against the model's token-counting implementation.
    assert isinstance(response, int)
    assert response == 6