diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index bda53e4394..1c27b2b4aa 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -26,5 +26,6 @@ - yi - openllm - localai +- volcengine_maas - openai_api_compatible - deepseek diff --git a/api/core/model_runtime/model_providers/volcengine_maas/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg new file mode 100644 index 0000000000..616e90916b --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg new file mode 100644 index 0000000000..24b92195bd --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg new file mode 100644 index 0000000000..e6454a89b7 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py new file mode 100644 index 0000000000..c7bf4fde8c --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -0,0 +1,108 @@ +import re +from collections.abc import Callable, Generator +from typing import cast + +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error +from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService + + +class MaaSClient(MaasService): + def __init__(self, host: str, region: str): + self.endpoint_id = None + super().__init__(host, region) + + def set_endpoint_id(self, endpoint_id: str): + self.endpoint_id = endpoint_id + + @classmethod + def from_credential(cls, credentials: dict) -> 'MaaSClient': + host = credentials['api_endpoint_host'] + region = credentials['volc_region'] + ak = credentials['volc_access_key_id'] + sk = credentials['volc_secret_access_key'] + endpoint_id = credentials['endpoint_id'] + + client = cls(host, region) + client.set_endpoint_id(endpoint_id) + client.set_ak(ak) + client.set_sk(sk) + return client + + def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict: + req = { + 'parameters': params, + 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages] + } + if not stream: + return super().chat( + self.endpoint_id, + req, + ) + return super().stream_chat( + self.endpoint_id, + req, + ) + + def embeddings(self, texts: list[str]) -> dict: + req = { + 'input': texts + } + return super().embeddings(self.endpoint_id, req) + + @staticmethod + def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": ChatRole.USER, + "content": message.content} + else: + content = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + raise ValueError( + 'Content object type only support image_url') + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast( + ImagePromptMessageContent, message_content) + image_data = re.sub( + r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + content.append({ + 'type': 'image_url', + 'image_url': { + 'url': '', + 'image_bytes': image_data, + 'detail': message_content.detail, + } + }) + + message_dict = {'role': ChatRole.USER, 'content': content} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {'role': ChatRole.ASSISTANT, + 'content': message.content} + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {'role': ChatRole.SYSTEM, + 'content': message.content} + else: + raise ValueError(f"Got unknown PromptMessage type {message}") + + return message_dict + + @staticmethod + def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: + try: + resp = fn() + except MaasException as e: + raise wrap_error(e) + + return resp diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/errors.py new file mode 100644 index 0000000000..63397a456e --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/errors.py @@ -0,0 +1,156 @@ +from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException + + +class ClientSDKRequestError(MaasException): + pass + + +class SignatureDoesNotMatch(MaasException): + pass + + +class RequestTimeout(MaasException): + pass + + +class ServiceConnectionTimeout(MaasException): + pass + + +class MissingAuthenticationHeader(MaasException): + pass + + +class AuthenticationHeaderIsInvalid(MaasException): + pass + + +class InternalServiceError(MaasException): + pass + + +class MissingParameter(MaasException): + pass + + +class InvalidParameter(MaasException): + pass + + +class AuthenticationExpire(MaasException): + pass + + +class EndpointIsInvalid(MaasException): + pass + + +class EndpointIsNotEnable(MaasException): + pass + + +class ModelNotSupportStreamMode(MaasException): + pass + + +class ReqTextExistRisk(MaasException): + pass + + +class RespTextExistRisk(MaasException): + pass + + +class EndpointRateLimitExceeded(MaasException): + pass + + +class ServiceConnectionRefused(MaasException): + pass + + +class ServiceConnectionClosed(MaasException): + pass + + +class UnauthorizedUserForEndpoint(MaasException): + pass + + +class InvalidEndpointWithNoURL(MaasException): + pass + + +class EndpointAccountRpmRateLimitExceeded(MaasException): + pass + + +class EndpointAccountTpmRateLimitExceeded(MaasException): + pass + + +class ServiceResourceWaitQueueFull(MaasException): + pass + + +class EndpointIsPending(MaasException): + pass + + +class ServiceNotOpen(MaasException): + pass + + +AuthErrors = { + 'SignatureDoesNotMatch': SignatureDoesNotMatch, + 'MissingAuthenticationHeader': MissingAuthenticationHeader, + 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid, + 'AuthenticationExpire': AuthenticationExpire, + 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint, +} + +BadRequestErrors = { + 'MissingParameter': MissingParameter, + 'InvalidParameter': InvalidParameter, + 'EndpointIsInvalid': EndpointIsInvalid, + 'EndpointIsNotEnable': EndpointIsNotEnable, + 'ModelNotSupportStreamMode': ModelNotSupportStreamMode, + 'ReqTextExistRisk': ReqTextExistRisk, + 'RespTextExistRisk': RespTextExistRisk, + 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL, + 'ServiceNotOpen': ServiceNotOpen, +} + +RateLimitErrors = { + 'EndpointRateLimitExceeded': EndpointRateLimitExceeded, + 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded, + 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded, +} + +ServerUnavailableErrors = { + 'InternalServiceError': InternalServiceError, + 'EndpointIsPending': EndpointIsPending, + 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull, +} + +ConnectionErrors = { + 'ClientSDKRequestError': ClientSDKRequestError, + 'RequestTimeout': RequestTimeout, + 'ServiceConnectionTimeout': ServiceConnectionTimeout, + 'ServiceConnectionRefused': ServiceConnectionRefused, + 'ServiceConnectionClosed': ServiceConnectionClosed, +} + +ErrorCodeMap = { + **AuthErrors, + **BadRequestErrors, + **RateLimitErrors, + **ServerUnavailableErrors, + **ConnectionErrors, +} + + +def wrap_error(e: MaasException) -> Exception: + if ErrorCodeMap.get(e.code): + return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) + return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py new file mode 100644 index 0000000000..7a36d019e2 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -0,0 +1,284 @@ +import logging +from collections.abc import Generator + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.errors import ( + AuthErrors, + BadRequestErrors, + ConnectionErrors, + RateLimitErrors, + ServerUnavailableErrors, +) +from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs +from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException + +logger = logging.getLogger(__name__) + + +class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate credentials + """ + # ping + client = MaaSClient.from_credential(credentials) + try: + client.chat( + { + 'max_new_tokens': 16, + 'temperature': 0.7, + 'top_p': 0.9, + 'top_k': 15, + }, + [UserPromptMessage(content='ping\nAnswer: ')], + ) + except MaasException as e: + raise CredentialsValidateFailedError(e.message) + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None) -> int: + if len(prompt_messages) == 0: + return 0 + return self._num_tokens_from_messages(prompt_messages) + + def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: + """ + Calculate num tokens. + + :param messages: messages + """ + num_tokens = 0 + messages_dict = [ + MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + for message in messages_dict: + for key, value in message.items(): + num_tokens += self._get_num_tokens_by_gpt2(str(key)) + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + return num_tokens + + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + + client = MaaSClient.from_credential(credentials) + + req_params = ModelConfigs.get( + credentials['base_model_name'], {}).get('req_params', {}).copy() + if credentials.get('context_size'): + req_params['max_prompt_tokens'] = credentials.get('context_size') + if credentials.get('max_tokens'): + req_params['max_new_tokens'] = credentials.get('max_tokens') + if model_parameters.get('max_tokens'): + req_params['max_new_tokens'] = model_parameters.get('max_tokens') + if model_parameters.get('temperature'): + req_params['temperature'] = model_parameters.get('temperature') + if model_parameters.get('top_p'): + req_params['top_p'] = model_parameters.get('top_p') + if model_parameters.get('top_k'): + req_params['top_k'] = model_parameters.get('top_k') + if model_parameters.get('presence_penalty'): + req_params['presence_penalty'] = model_parameters.get( + 'presence_penalty') + if model_parameters.get('frequency_penalty'): + req_params['frequency_penalty'] = model_parameters.get( + 'frequency_penalty') + if stop: + req_params['stop'] = stop + + resp = MaaSClient.wrap_exception( + lambda: client.chat(req_params, prompt_messages, stream)) + if not stream: + return self._handle_chat_response(model, credentials, prompt_messages, resp) + return self._handle_stream_chat_response(model, credentials, prompt_messages, resp) + + def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator: + for index, r in enumerate(resp): + choices = r['choices'] + if not choices: + continue + choice = choices[0] + message = choice['message'] + usage = None + if r.get('usage'): + usage = self._calc_usage(model, credentials, r['usage']) + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=[] + ), + usage=usage, + finish_reason=choice.get('finish_reason'), + ), + ) + + def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult: + choices = resp['choices'] + if not choices: + return + choice = choices[0] + message = choice['message'] + + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=[], + ), + usage=self._calc_usage(model, credentials, resp['usage']), + ) + + def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage: + return self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage['prompt_tokens'], + completion_tokens=usage['completion_tokens'] + ) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + max_tokens = ModelConfigs.get( + credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens') + if credentials.get('max_tokens'): + max_tokens = int(credentials.get('max_tokens')) + rules = [ + ParameterRule( + name='temperature', + type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', + en_US='Temperature' + ) + ), + ParameterRule( + name='top_p', + type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', + en_US='Top P' + ) + ), + ParameterRule( + name='top_k', + type=ParameterType.INT, + min=1, + default=1, + label=I18nObject( + zh_Hans='Top K', + en_US='Top K' + ) + ), + ParameterRule( + name='presence_penalty', + type=ParameterType.FLOAT, + use_template='presence_penalty', + label={ + 'en_US': 'Presence Penalty', + 'zh_Hans': '存在惩罚', + }, + min=-2.0, + max=2.0, + ), + ParameterRule( + name='frequency_penalty', + type=ParameterType.FLOAT, + use_template='frequency_penalty', + label={ + 'en_US': 'Frequency Penalty', + 'zh_Hans': '频率惩罚', + }, + min=-2.0, + max=2.0, + ), + ParameterRule( + name='max_tokens', + type=ParameterType.INT, + use_template='max_tokens', + min=1, + max=max_tokens, + default=512, + label=I18nObject( + zh_Hans='最大生成长度', + en_US='Max Tokens' + ) + ), + ] + + model_properties = ModelConfigs.get( + credentials['base_model_name'], {}).get('model_properties', {}).copy() + if credentials.get('mode'): + model_properties[ModelPropertyKey.MODE] = credentials.get('mode') + if credentials.get('context_size'): + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( + credentials.get('context_size', 4096)) + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + model_properties=model_properties, + parameter_rules=rules + ) + + return entity + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: ConnectionErrors.values(), + InvokeServerUnavailableError: ServerUnavailableErrors.values(), + InvokeRateLimitError: RateLimitErrors.values(), + InvokeAuthorizationError: AuthErrors.values(), + InvokeBadRequestError: BadRequestErrors.values(), + } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py new file mode 100644 index 0000000000..d022f0069b --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -0,0 +1,12 @@ +ModelConfigs = { + 'Skylark2-pro-4k': { + 'req_params': { + 'max_prompt_tokens': 4096, + 'max_new_tokens': 4000, + }, + 'model_properties': { + 'context_size': 4096, + 'mode': 'chat', + } + } +} diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py new file mode 100644 index 0000000000..d63399aec2 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -0,0 +1,132 @@ +import time +from typing import Optional + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.errors import ( + AuthErrors, + BadRequestErrors, + ConnectionErrors, + RateLimitErrors, + ServerUnavailableErrors, +) +from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException + + +class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for VolcengineMaaS text embedding model. + """ + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + client = MaaSClient.from_credential(credentials) + resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) + + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=resp['total_tokens']) + + result = TextEmbeddingResult( + model=model, + embeddings=[v['embedding'] for v in resp['data']], + usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + num_tokens = 0 + for text in texts: + # use GPT2Tokenizer to get num tokens + num_tokens += self._get_num_tokens_by_gpt2(text) + return num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke(model=model, credentials=credentials, texts=['ping']) + except MaasException as e: + raise CredentialsValidateFailedError(e.message) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: ConnectionErrors.values(), + InvokeServerUnavailableError: ServerUnavailableErrors.values(), + InvokeRateLimitError: RateLimitErrors.values(), + InvokeAuthorizationError: AuthErrors.values(), + InvokeBadRequestError: BadRequestErrors.values(), + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py new file mode 100644 index 0000000000..64f342f16e --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py @@ -0,0 +1,4 @@ +from .common import ChatRole +from .maas import MaasException, MaasService + +__all__ = ['MaasService', 'ChatRole', 'MaasException'] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py @@ -0,0 +1 @@ + diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py new file mode 100644 index 0000000000..48110f16d7 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py @@ -0,0 +1,144 @@ +# coding : utf-8 +import datetime + +import pytz + +from .util import Util + + +class MetaData: + def __init__(self): + self.algorithm = '' + self.credential_scope = '' + self.signed_headers = '' + self.date = '' + self.region = '' + self.service = '' + + def set_date(self, date): + self.date = date + + def set_service(self, service): + self.service = service + + def set_region(self, region): + self.region = region + + def set_algorithm(self, algorithm): + self.algorithm = algorithm + + def set_credential_scope(self, credential_scope): + self.credential_scope = credential_scope + + def set_signed_headers(self, signed_headers): + self.signed_headers = signed_headers + + +class SignResult: + def __init__(self): + self.xdate = '' + self.xCredential = '' + self.xAlgorithm = '' + self.xSignedHeaders = '' + self.xSignedQueries = '' + self.xSignature = '' + self.xContextSha256 = '' + self.xSecurityToken = '' + + self.authorization = '' + + def __str__(self): + return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()]) + + +class Credentials: + def __init__(self, ak, sk, service, region, session_token=''): + self.ak = ak + self.sk = sk + self.service = service + self.region = region + self.session_token = session_token + + def set_ak(self, ak): + self.ak = ak + + def set_sk(self, sk): + self.sk = sk + + def set_session_token(self, session_token): + self.session_token = session_token + + +class Signer: + @staticmethod + def sign(request, credentials): + if request.path == '': + request.path = '/' + if request.method != 'GET' and not ('Content-Type' in request.headers): + request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' + + format_date = Signer.get_current_format_date() + request.headers['X-Date'] = format_date + if credentials.session_token != '': + request.headers['X-Security-Token'] = credentials.session_token + + md = MetaData() + md.set_algorithm('HMAC-SHA256') + md.set_service(credentials.service) + md.set_region(credentials.region) + md.set_date(format_date[:8]) + + hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) + md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request'])) + + signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) + signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) + sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) + request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials) + return + + @staticmethod + def hashed_canonical_request_v4(request, meta): + body_hash = Util.sha256(request.body) + request.headers['X-Content-Sha256'] = body_hash + + signed_headers = dict() + for key in request.headers: + if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): + signed_headers[key.lower()] = request.headers[key] + + if 'host' in signed_headers: + v = signed_headers['host'] + if v.find(':') != -1: + split = v.split(':') + port = split[1] + if str(port) == '80' or str(port) == '443': + signed_headers['host'] = split[0] + + signed_str = '' + for key in sorted(signed_headers.keys()): + signed_str += key + ':' + signed_headers[key] + '\n' + + meta.set_signed_headers(';'.join(sorted(signed_headers.keys()))) + + canonical_request = '\n'.join( + [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str, + meta.signed_headers, body_hash]) + + return Util.sha256(canonical_request) + + @staticmethod + def get_signing_secret_key_v4(sk, date, region, service): + date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date) + region = Util.hmac_sha256(date, region) + service = Util.hmac_sha256(region, service) + return Util.hmac_sha256(service, 'request') + + @staticmethod + def build_auth_header_v4(signature, meta, credentials): + credential = credentials.ak + '/' + meta.credential_scope + return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature + + @staticmethod + def get_current_format_date(): + return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py new file mode 100644 index 0000000000..03734ec54f --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py @@ -0,0 +1,207 @@ +import json +from collections import OrderedDict +from urllib.parse import urlencode + +import requests + +from .auth import Signer + +VERSION = 'v1.0.137' + + +class Service: + def __init__(self, service_info, api_info): + self.service_info = service_info + self.api_info = api_info + self.session = requests.session() + + def set_ak(self, ak): + self.service_info.credentials.set_ak(ak) + + def set_sk(self, sk): + self.service_info.credentials.set_sk(sk) + + def set_session_token(self, session_token): + self.service_info.credentials.set_session_token(session_token) + + def set_host(self, host): + self.service_info.host = host + + def set_scheme(self, scheme): + self.service_info.scheme = scheme + + def get(self, api, params, doseq=0): + if not (api in self.api_info): + raise Exception("no such api") + api_info = self.api_info[api] + + r = self.prepare_request(api_info, params, doseq) + + Signer.sign(r, self.service_info.credentials) + + url = r.build(doseq) + resp = self.session.get(url, headers=r.headers, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + if resp.status_code == 200: + return resp.text + else: + raise Exception(resp.text) + + def post(self, api, params, form): + if not (api in self.api_info): + raise Exception("no such api") + api_info = self.api_info[api] + r = self.prepare_request(api_info, params) + r.headers['Content-Type'] = 'application/x-www-form-urlencoded' + r.form = self.merge(api_info.form, form) + r.body = urlencode(r.form, True) + Signer.sign(r, self.service_info.credentials) + + url = r.build() + + resp = self.session.post(url, headers=r.headers, data=r.form, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + if resp.status_code == 200: + return resp.text + else: + raise Exception(resp.text) + + def json(self, api, params, body): + if not (api in self.api_info): + raise Exception("no such api") + api_info = self.api_info[api] + r = self.prepare_request(api_info, params) + r.headers['Content-Type'] = 'application/json' + r.body = body + + Signer.sign(r, self.service_info.credentials) + + url = r.build() + resp = self.session.post(url, headers=r.headers, data=r.body, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + if resp.status_code == 200: + return json.dumps(resp.json()) + else: + raise Exception(resp.text.encode("utf-8")) + + def put(self, url, file_path, headers): + with open(file_path, 'rb') as f: + resp = self.session.put(url, headers=headers, data=f) + if resp.status_code == 200: + return True, resp.text.encode("utf-8") + else: + return False, resp.text.encode("utf-8") + + def put_data(self, url, data, headers): + resp = self.session.put(url, headers=headers, data=data) + if resp.status_code == 200: + return True, resp.text.encode("utf-8") + else: + return False, resp.text.encode("utf-8") + + def prepare_request(self, api_info, params, doseq=0): + for key in params: + if type(params[key]) == int or type(params[key]) == float or type(params[key]) == bool: + params[key] = str(params[key]) + elif type(params[key]) == list: + if not doseq: + params[key] = ','.join(params[key]) + + connection_timeout = self.service_info.connection_timeout + socket_timeout = self.service_info.socket_timeout + + r = Request() + r.set_schema(self.service_info.scheme) + r.set_method(api_info.method) + r.set_connection_timeout(connection_timeout) + r.set_socket_timeout(socket_timeout) + + headers = self.merge(api_info.header, self.service_info.header) + headers['Host'] = self.service_info.host + headers['User-Agent'] = 'volc-sdk-python/' + VERSION + r.set_headers(headers) + + query = self.merge(api_info.query, params) + r.set_query(query) + + r.set_host(self.service_info.host) + r.set_path(api_info.path) + + return r + + @staticmethod + def merge(param1, param2): + od = OrderedDict() + for key in param1: + od[key] = param1[key] + + for key in param2: + od[key] = param2[key] + + return od + + +class Request: + def __init__(self): + self.schema = '' + self.method = '' + self.host = '' + self.path = '' + self.headers = OrderedDict() + self.query = OrderedDict() + self.body = '' + self.form = dict() + self.connection_timeout = 0 + self.socket_timeout = 0 + + def set_schema(self, schema): + self.schema = schema + + def set_method(self, method): + self.method = method + + def set_host(self, host): + self.host = host + + def set_path(self, path): + self.path = path + + def set_headers(self, headers): + self.headers = headers + + def set_query(self, query): + self.query = query + + def set_body(self, body): + self.body = body + + def set_connection_timeout(self, connection_timeout): + self.connection_timeout = connection_timeout + + def set_socket_timeout(self, socket_timeout): + self.socket_timeout = socket_timeout + + def build(self, doseq=0): + return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq) + + +class ServiceInfo: + def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'): + self.host = host + self.header = header + self.credentials = credentials + self.connection_timeout = connection_timeout + self.socket_timeout = socket_timeout + self.scheme = scheme + + +class ApiInfo: + def __init__(self, method, path, query, form, header): + self.method = method + self.path = path + self.query = query + self.form = form + self.header = header + + def __str__(self): + return 'method: ' + self.method + ', path: ' + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py new file mode 100644 index 0000000000..7eb5fdfa91 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py @@ -0,0 +1,43 @@ +import hashlib +import hmac +from functools import reduce +from urllib.parse import quote + + +class Util: + @staticmethod + def norm_uri(path): + return quote(path).replace('%2F', '/').replace('+', '%20') + + @staticmethod + def norm_query(params): + query = '' + for key in sorted(params.keys()): + if type(params[key]) == list: + for k in params[key]: + query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&' + else: + query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&' + query = query[:-1] + return query.replace('+', '%20') + + @staticmethod + def hmac_sha256(key, content): + return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest() + + @staticmethod + def sha256(content): + if isinstance(content, str) is True: + return hashlib.sha256(content.encode('utf-8')).hexdigest() + else: + return hashlib.sha256(content).hexdigest() + + @staticmethod + def to_hex(content): + lst = [] + for ch in content: + hv = hex(ch).replace('0x', '') + if len(hv) == 1: + hv = '0' + hv + lst.append(hv) + return reduce(lambda x, y: x + y, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py new file mode 100644 index 0000000000..8b14d026d9 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py @@ -0,0 +1,79 @@ +import json +import random +from datetime import datetime + + +class ChatRole: + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + + +class _Dict(dict): + __setattr__ = dict.__setitem__ + __getattr__ = dict.__getitem__ + + def __missing__(self, key): + return None + + +def dict_to_object(dict_obj): + # 支持嵌套类型 + if isinstance(dict_obj, list): + insts = [] + for i in dict_obj: + insts.append(dict_to_object(i)) + return insts + + if isinstance(dict_obj, dict): + inst = _Dict() + for k, v in dict_obj.items(): + inst[k] = dict_to_object(v) + return inst + + return dict_obj + + +def json_to_object(json_str, req_id=None): + obj = dict_to_object(json.loads(json_str)) + if obj and isinstance(obj, dict) and req_id: + obj["req_id"] = req_id + return obj + + +def gen_req_id(): + return datetime.now().strftime("%Y%m%d%H%M%S") + format( + random.randint(0, 2 ** 64 - 1), "020X" + ) + + +class SSEDecoder: + def __init__(self, source): + self.source = source + + def _read(self): + data = b'' + for chunk in self.source: + for line in chunk.splitlines(True): + data += line + if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + yield data + data = b'' + if data: + yield data + + def next(self): + for chunk in self._read(): + for line in chunk.splitlines(): + # skip comment + if line.startswith(b':'): + continue + + if b':' in line: + field, value = line.split(b':', 1) + else: + field, value = line, b'' + + if field == b'data' and len(value) > 0: + yield value diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py new file mode 100644 index 0000000000..3cbe9d9f09 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py @@ -0,0 +1,213 @@ +import copy +import json +from collections.abc import Iterator + +from .base.auth import Credentials, Signer +from .base.service import ApiInfo, Service, ServiceInfo +from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object + + +class MaasService(Service): + def __init__(self, host, region, connection_timeout=60, socket_timeout=60): + service_info = self.get_service_info( + host, region, connection_timeout, socket_timeout + ) + self._apikey = None + api_info = self.get_api_info() + super().__init__(service_info, api_info) + + def set_apikey(self, apikey): + self._apikey = apikey + + @staticmethod + def get_service_info(host, region, connection_timeout, socket_timeout): + service_info = ServiceInfo( + host, + {"Accept": "application/json"}, + Credentials("", "", "ml_maas", region), + connection_timeout, + socket_timeout, + "https", + ) + return service_info + + @staticmethod + def get_api_info(): + api_info = { + "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), + "embeddings": ApiInfo( + "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {} + ), + } + return api_info + + def chat(self, endpoint_id, req): + req["stream"] = False + return self._request(endpoint_id, "chat", req) + + def stream_chat(self, endpoint_id, req): + req_id = gen_req_id() + self._validate("chat", req_id) + apikey = self._apikey + + try: + req["stream"] = True + res = self._call( + endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True + ) + + decoder = SSEDecoder(res) + + def iter_fn(): + for data in decoder.next(): + if data == b"[DONE]": + return + + try: + res = json_to_object( + str(data, encoding="utf-8"), req_id=req_id) + except Exception: + raise + + if res.error is not None and res.error.code_n != 0: + raise MaasException( + res.error.code_n, + res.error.code, + res.error.message, + req_id, + ) + yield res + + return iter_fn() + except MaasException: + raise + except Exception as e: + raise new_client_sdk_request_error(str(e)) + + def embeddings(self, endpoint_id, req): + return self._request(endpoint_id, "embeddings", req) + + def _request(self, endpoint_id, api, req, params={}): + req_id = gen_req_id() + + self._validate(api, req_id) + + apikey = self._apikey + + try: + res = self._call(endpoint_id, api, req_id, params, + json.dumps(req).encode("utf-8"), apikey) + resp = dict_to_object(res.json()) + if resp and isinstance(resp, dict): + resp["req_id"] = req_id + return resp + + except MaasException as e: + raise e + except Exception as e: + raise new_client_sdk_request_error(str(e), req_id) + + def _validate(self, api, req_id): + credentials_exist = ( + self.service_info.credentials is not None and + self.service_info.credentials.sk is not None and + self.service_info.credentials.ak is not None + ) + + if not self._apikey and not credentials_exist: + raise new_client_sdk_request_error("no valid credential", req_id) + + if not (api in self.api_info): + raise new_client_sdk_request_error("no such api", req_id) + + def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False): + api_info = copy.deepcopy(self.api_info[api]) + api_info.path = api_info.path.format(endpoint_id=endpoint_id) + + r = self.prepare_request(api_info, params) + r.headers["x-tt-logid"] = req_id + r.headers["Content-Type"] = "application/json" + r.body = body + + if apikey is None: + Signer.sign(r, self.service_info.credentials) + elif apikey is not None: + r.headers["Authorization"] = "Bearer " + apikey + + url = r.build() + res = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=( + self.service_info.connection_timeout, + self.service_info.socket_timeout, + ), + stream=stream, + ) + + if res.status_code != 200: + raw = res.text.encode() + res.close() + try: + resp = json_to_object( + str(raw, encoding="utf-8"), req_id=req_id) + except Exception: + raise new_client_sdk_request_error(raw, req_id) + + if resp.error: + raise MaasException( + resp.error.code_n, resp.error.code, resp.error.message, req_id + ) + else: + raise new_client_sdk_request_error(resp, req_id) + + return res + + +class MaasException(Exception): + def __init__(self, code_n, code, message, req_id): + self.code_n = code_n + self.code = code + self.message = message + self.req_id = req_id + + def __str__(self): + return ("Detailed exception information is listed below.\n" + + "req_id: {}\n" + + "code_n: {}\n" + + "code: {}\n" + + "message: {}").format(self.req_id, self.code_n, self.code, self.message) + + +def new_client_sdk_request_error(raw, req_id=""): + return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) + + +class BinaryResponseContent: + def __init__(self, response, request_id) -> None: + self.response = response + self.request_id = request_id + + def stream_to_file( + self, + file: str + ) -> None: + is_first = True + error_bytes = b'' + with open(file, mode="wb") as f: + for data in self.response: + if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)): + error_bytes += data + else: + f.write(data) + + if len(error_bytes) > 0: + resp = json_to_object( + str(error_bytes, encoding="utf-8"), req_id=self.request_id) + raise MaasException( + resp.error.code_n, resp.error.code, resp.error.message, self.request_id + ) + + def iter_bytes(self) -> Iterator[bytes]: + yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py new file mode 100644 index 0000000000..10f9be2d08 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VolcengineMaaSProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml new file mode 100644 index 0000000000..4f299ecae0 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -0,0 +1,151 @@ +provider: volcengine_maas +label: + en_US: Volcengine +description: + en_US: Volcengine MaaS models. +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg + zh_Hans: icon_l_zh.svg +background: "#F9FAFB" +help: + title: + en_US: Get your Access Key and Secret Access Key from Volcengine Console + url: + en_US: https://console.volcengine.com/iam/keymanage/ +supported_model_types: + - llm + - text-embedding +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your Model Name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: volc_access_key_id + required: true + label: + en_US: Access Key + zh_Hans: Access Key + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 输入您的 Access Key + - variable: volc_secret_access_key + required: true + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 输入您的 Secret Access Key + - variable: volc_region + required: true + label: + en_US: Volcengine Region + zh_Hans: 火山引擎地区 + type: text-input + default: cn-beijing + placeholder: + en_US: Enter Volcengine Region + zh_Hans: 输入火山引擎地域 + - variable: api_endpoint_host + required: true + label: + en_US: API Endpoint Host + zh_Hans: API Endpoint Host + type: text-input + default: maas-api.ml-platform-cn-beijing.volces.com + placeholder: + en_US: Enter your API Endpoint Host + zh_Hans: 输入 API Endpoint Host + - variable: endpoint_id + required: true + label: + en_US: Endpoint ID + zh_Hans: Endpoint ID + type: text-input + placeholder: + en_US: Enter your Endpoint ID + zh_Hans: 输入您的 Endpoint ID + - variable: base_model_name + show_on: + - variable: __model_type + value: llm + label: + en_US: Base Model + zh_Hans: 基础模型 + type: select + required: true + options: + - label: + en_US: Skylark2-pro-4k + value: Skylark2-pro-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Custom + zh_Hans: 自定义 + value: Custom + - variable: mode + required: true + show_on: + - variable: __model_type + value: llm + - variable: base_model_name + value: Custom + label: + zh_Hans: 模型类型 + en_US: Completion Mode + type: select + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select Completion Mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + required: true + show_on: + - variable: __model_type + value: llm + - variable: base_model_name + value: Custom + label: + zh_Hans: 模型上下文长度 + en_US: Model Context Size + type: text-input + default: '4096' + placeholder: + zh_Hans: 输入您的模型上下文长度 + en_US: Enter your Model Context Size + - variable: max_tokens + required: true + show_on: + - variable: __model_type + value: llm + - variable: base_model_name + value: Custom + label: + zh_Hans: 最大 token 上限 + en_US: Upper Bound for Max Tokens + default: '4096' + type: text-input + placeholder: + zh_Hans: 输入您的模型最大 token 上限 + en_US: Enter your model Upper Bound for Max Tokens diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 9cd04b4764..f29e5ef4d6 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -73,4 +73,10 @@ MOCK_SWITCH=false # CODE EXECUTION CONFIGURATION CODE_EXECUTION_ENDPOINT= -CODE_EXECUTION_API_KEY= \ No newline at end of file +CODE_EXECUTION_API_KEY= + +# Volcengine MaaS Credentials +VOLC_API_KEY= +VOLC_SECRET_KEY= +VOLC_MODEL_ENDPOINT_ID= +VOLC_EMBEDDING_ENDPOINT_ID= \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py b/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py new file mode 100644 index 0000000000..61e9f704af --- /dev/null +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -0,0 +1,81 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import ( + VolcengineMaaSTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = VolcengineMaaSTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': 'INVALID', + 'volc_secret_access_key': 'INVALID', + 'endpoint_id': 'INVALID', + } + ) + + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + }, + ) + + +def test_invoke_model(): + model = VolcengineMaaSTextEmbeddingModel() + + result = model.invoke( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 + + +def test_get_num_tokens(): + model = VolcengineMaaSTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), + }, + texts=[ + "hello", + "world" + ] + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py new file mode 100644 index 0000000000..63835d0263 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py @@ -0,0 +1,131 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = VolcengineMaaSLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': 'INVALID', + 'volc_secret_access_key': 'INVALID', + 'endpoint_id': 'INVALID', + } + ) + + model.validate_credentials( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + } + ) + + +def test_invoke_model(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.invoke( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + 'base_model_name': 'Skylark2-pro-4k', + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 1, + }, + stop=['you'], + user="abc-123", + stream=False + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_model(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.invoke( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + 'base_model_name': 'Skylark2-pro-4k', + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 1, + }, + stop=['you'], + stream=True, + user="abc-123" + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len( + chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.get_num_tokens( + model='NOT IMPORTANT', + credentials={ + 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', + 'volc_region': 'cn-beijing', + 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), + 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), + 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), + 'base_model_name': 'Skylark2-pro-4k', + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' + ) + ], + tools=[] + ) + + assert isinstance(response, int) + assert response == 6