From bfd905602fc6dbaa19e87168b1acdb6d38059e8b Mon Sep 17 00:00:00 2001 From: Chengyu Yan Date: Mon, 19 Aug 2024 09:15:19 +0800 Subject: [PATCH] feat(api): support wenxin text embedding (#7377) --- .../model_providers/wenxin/_common.py | 195 ++++++++++++++++++ .../model_providers/wenxin/llm/ernie_bot.py | 180 +--------------- .../wenxin/llm/ernie_bot_errors.py | 17 -- .../model_providers/wenxin/llm/llm.py | 39 +--- .../wenxin/text_embedding/__init__.py | 0 .../wenxin/text_embedding/embedding-v1.yaml | 9 + .../wenxin/text_embedding/text_embedding.py | 184 +++++++++++++++++ .../model_providers/wenxin/wenxin.yaml | 1 + .../model_providers/wenxin/wenxin_errors.py | 57 +++++ .../model_runtime/wenxin/test_embedding.py | 24 +++ .../unit_tests/core/model_runtime/__init__.py | 0 .../model_runtime/model_providers/__init__.py | 0 .../model_providers/wenxin/__init__.py | 0 .../wenxin/test_text_embedding.py | 75 +++++++ 14 files changed, 553 insertions(+), 228 deletions(-) create mode 100644 api/core/model_runtime/model_providers/wenxin/_common.py delete mode 100644 api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py create mode 100644 api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py create mode 100644 api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml create mode 100644 api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py create mode 100644 api/core/model_runtime/model_providers/wenxin/wenxin_errors.py create mode 100644 api/tests/integration_tests/model_runtime/wenxin/test_embedding.py create mode 100644 api/tests/unit_tests/core/model_runtime/__init__.py create mode 100644 api/tests/unit_tests/core/model_runtime/model_providers/__init__.py create mode 100644 api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py create mode 100644 api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py diff --git 
a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py new file mode 100644 index 0000000000..ee9c34b6ac --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -0,0 +1,195 @@ +from datetime import datetime, timedelta +from threading import Lock + +from requests import post + +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + +baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens_lock = Lock() + + +class BaiduAccessToken: + api_key: str + access_token: str + expires: datetime + + def __init__(self, api_key: str) -> None: + self.api_key = api_key + self.access_token = '' + self.expires = datetime.now() + timedelta(days=3) + + @staticmethod + def _get_access_token(api_key: str, secret_key: str) -> str: + """ + request access token from Baidu + """ + try: + response = post( + url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json' + }, + ) + except Exception as e: + raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + + resp = response.json() + if 'error' in resp: + if resp['error'] == 'invalid_client': + raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') + elif resp['error'] == 'unknown_error': + raise InternalServerError(f'Internal server error: {resp["error_description"]}') + elif resp['error'] == 'invalid_request': + raise BadRequestError(f'Bad request: {resp["error_description"]}') + elif resp['error'] == 'rate_limit_exceeded': + raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') + else: + raise Exception(f'Unknown error: {resp["error_description"]}') + + return 
resp['access_token'] + + @staticmethod + def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + """ + LLM from Baidu requires access token to invoke the API. + however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. (avoid memory leak) + + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. + """ + + # look up cache, remove expired access token + baidu_access_tokens_lock.acquire() + now = datetime.now() + for key in list(baidu_access_tokens.keys()): + token = baidu_access_tokens[key] + if token.expires < now: + baidu_access_tokens.pop(key) + + if api_key not in baidu_access_tokens: + # if access token not in cache, request it + token = BaiduAccessToken(api_key) + baidu_access_tokens[api_key] = token + # release it to enhance performance + # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock + baidu_access_tokens_lock.release() + # try to get access token + token_str = BaiduAccessToken._get_access_token(api_key, secret_key) + token.access_token = token_str + token.expires = now + timedelta(days=3) + return token + else: + # if access token in cache, return it + token = baidu_access_tokens[api_key] + baidu_access_tokens_lock.release() + return token + + +class _CommonWenxin: + api_bases = { + 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', + 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', + 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', + 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', + 'ernie-3.5-8k-0205': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', + 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', + 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', + 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', + 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', + 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', + 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', + 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', + 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', + 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', + 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', + 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', + 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', + 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', + 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', + } + + function_calling_supports = [ + 'ernie-bot', + 'ernie-bot-8k', + 'ernie-3.5-8k', + 'ernie-3.5-8k-0205', + 'ernie-3.5-8k-1222', + 'ernie-3.5-4k-0205', + 'ernie-3.5-128k', + 
'ernie-4.0-8k', + 'ernie-4.0-turbo-8k', + 'ernie-4.0-turbo-8k-preview', + 'yi_34b_chat' + ] + + api_key: str = '' + secret_key: str = '' + + def __init__(self, api_key: str, secret_key: str): + self.api_key = api_key + self.secret_key = secret_key + + @staticmethod + def _to_credential_kwargs(credentials: dict) -> dict: + credentials_kwargs = { + "api_key": credentials['api_key'], + "secret_key": credentials['secret_key'] + } + return credentials_kwargs + + def _handle_error(self, code: int, msg: str): + error_map = { + 1: InternalServerError, + 2: InternalServerError, + 3: BadRequestError, + 4: RateLimitReachedError, + 6: InvalidAuthenticationError, + 13: InvalidAPIKeyError, + 14: InvalidAPIKeyError, + 15: InvalidAPIKeyError, + 17: RateLimitReachedError, + 18: RateLimitReachedError, + 19: RateLimitReachedError, + 100: InvalidAPIKeyError, + 111: InvalidAPIKeyError, + 200: InternalServerError, + 336000: InternalServerError, + 336001: BadRequestError, + 336002: BadRequestError, + 336003: BadRequestError, + 336004: InvalidAuthenticationError, + 336005: InvalidAPIKeyError, + 336006: BadRequestError, + 336007: BadRequestError, + 336008: BadRequestError, + 336100: InternalServerError, + 336101: BadRequestError, + 336102: BadRequestError, + 336103: BadRequestError, + 336104: BadRequestError, + 336105: BadRequestError, + 336200: InternalServerError, + 336303: BadRequestError, + 337006: BadRequestError + } + + if code in error_map: + raise error_map[code](msg) + else: + raise InternalServerError(f'Unknown error: {msg}') + + def _get_access_token(self) -> str: + token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) + return token.access_token diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index e345663d36..8109949b1d 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ 
-1,102 +1,17 @@ from collections.abc import Generator -from datetime import datetime, timedelta from enum import Enum from json import dumps, loads -from threading import Lock from typing import Any, Union from requests import Response, post from core.model_runtime.entities.message_entities import PromptMessageTool -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( +from core.model_runtime.model_providers.wenxin._common import _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( BadRequestError, InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, ) -# map api_key to access_token -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} -baidu_access_tokens_lock = Lock() - -class BaiduAccessToken: - api_key: str - access_token: str - expires: datetime - - def __init__(self, api_key: str) -> None: - self.api_key = api_key - self.access_token = '' - self.expires = datetime.now() + timedelta(days=3) - - def _get_access_token(api_key: str, secret_key: str) -> str: - """ - request access token from Baidu - """ - try: - response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, - ) - except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') - - resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': - raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': - raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': - raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': - raise RateLimitReachedError(f'Rate limit reached: 
{resp["error_description"]}') - else: - raise Exception(f'Unknown error: {resp["error_description"]}') - - return resp['access_token'] - - @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': - """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) - - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. - """ - - # loop up cache, remove expired access token - baidu_access_tokens_lock.acquire() - now = datetime.now() - for key in list(baidu_access_tokens.keys()): - token = baidu_access_tokens[key] - if token.expires < now: - baidu_access_tokens.pop(key) - - if api_key not in baidu_access_tokens: - # if access token not in cache, request it - token = BaiduAccessToken(api_key) - baidu_access_tokens[api_key] = token - # release it to enhance performance - # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock - baidu_access_tokens_lock.release() - # try to get access token - token_str = BaiduAccessToken._get_access_token(api_key, secret_key) - token.access_token = token_str - token.expires = now + timedelta(days=3) - return token - else: - # if access token in cache, return it - token = baidu_access_tokens[api_key] - baidu_access_tokens_lock.release() - return token - class ErnieMessage: class Role(Enum): @@ -120,51 +35,7 @@ class ErnieMessage: self.content = content self.role = role -class ErnieBotModel: - api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', - } - - function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', 
- 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview', - 'yi_34b_chat' - ] - - api_key: str = '' - secret_key: str = '' - - def __init__(self, api_key: str, secret_key: str): - self.api_key = api_key - self.secret_key = secret_key +class ErnieBotModel(_CommonWenxin): def generate(self, model: str, stream: bool, messages: list[ErnieMessage], parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ @@ -199,51 +70,6 @@ class ErnieBotModel: return self._handle_chat_stream_generate_response(resp) return self._handle_chat_generate_response(resp) - def _handle_error(self, code: int, msg: str): - error_map = { - 1: InternalServerError, - 2: InternalServerError, - 3: BadRequestError, - 4: RateLimitReachedError, - 6: InvalidAuthenticationError, - 13: InvalidAPIKeyError, - 14: InvalidAPIKeyError, - 15: InvalidAPIKeyError, - 17: RateLimitReachedError, - 18: RateLimitReachedError, - 19: RateLimitReachedError, - 100: InvalidAPIKeyError, - 111: InvalidAPIKeyError, - 200: InternalServerError, - 336000: InternalServerError, - 336001: BadRequestError, - 336002: BadRequestError, - 336003: BadRequestError, - 336004: InvalidAuthenticationError, - 336005: InvalidAPIKeyError, - 336006: BadRequestError, - 336007: BadRequestError, - 336008: BadRequestError, - 336100: InternalServerError, - 336101: BadRequestError, - 336102: BadRequestError, - 336103: BadRequestError, - 336104: BadRequestError, - 336105: BadRequestError, - 336200: InternalServerError, - 336303: BadRequestError, - 337006: BadRequestError - } - - if code in error_map: - raise error_map[code](msg) - else: - raise InternalServerError(f'Unknown error: {msg}') - - def _get_access_token(self) -> str: - token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) - return token.access_token - def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return 
[ErnieMessage(message.content, message.role) for message in messages] diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py deleted file mode 100644 index 67d76b4a29..0000000000 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py +++ /dev/null @@ -1,17 +0,0 @@ -class InvalidAuthenticationError(Exception): - pass - -class InvalidAPIKeyError(Exception): - pass - -class RateLimitReachedError(Exception): - pass - -class InsufficientAccountBalance(Exception): - pass - -class InternalServerError(Exception): - pass - -class BadRequestError(Exception): - pass \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index d39d63deee..140606298c 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( - BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, -) +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken +from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage +from 
core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure @@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): api_key = credentials['api_key'] secret_key = credentials['secret_key'] try: - BaiduAccessToken._get_access_token(api_key, secret_key) + BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') @@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ - return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], - InvokeAuthorizationError: [ - InvalidAuthenticationError, - InsufficientAccountBalance, - InvalidAPIKeyError, - ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] - } + return invoke_error_mapping() diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml new file mode 100644 index 0000000000..eda48d9655 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml @@ -0,0 +1,9 @@ +model: embedding-v1 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py 
b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py new file mode 100644 index 0000000000..10ac1a1861 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -0,0 +1,184 @@ +import time +from abc import abstractmethod +from collections.abc import Mapping +from json import dumps +from typing import Any, Optional + +import numpy as np +from requests import Response, post + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + invoke_error_mapping, +) + + +class TextEmbedding: + @abstractmethod + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + raise NotImplementedError + + +class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + access_token = self._get_access_token() + url = f'{self.api_bases[model]}?access_token={access_token}' + body = self._build_embed_request_body(model, texts, user) + headers = { + 'Content-Type': 'application/json', + } + + resp = post(url, data=dumps(body), headers=headers) + if resp.status_code != 200: + raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + return self._handle_embed_response(model, resp) + + def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: + if len(texts) == 0: + raise 
BadRequestError('The number of texts should not be zero.') + body = { + 'input': texts, + 'user_id': user, + } + return body + + def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + data = response.json() + if 'error_code' in data: + code = data['error_code'] + msg = data['error_msg'] + # raise error + self._handle_error(code, msg) + + embeddings = [v['embedding'] for v in data['data']] + _usage = data['usage'] + tokens = _usage['prompt_tokens'] + total_tokens = _usage['total_tokens'] + + return embeddings, tokens, total_tokens + + +class WenxinTextEmbeddingModel(TextEmbeddingModel): + def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: + return WenxinTextEmbedding(api_key, secret_key) + + def _invoke(self, model: str, credentials: dict, texts: list[str], + user: Optional[str] = None) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + + api_key = credentials['api_key'] + secret_key = credentials['secret_key'] + embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) + user = user if user else 'ErnieBotDefault' + + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + inputs = [] + indices = [] + used_tokens = 0 + used_total_tokens = 0 + + for i, text in enumerate(texts): + + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + for i in _iter: + 
embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( + model, + inputs[i: i + max_chunks], + user) + used_tokens += _used_tokens + used_total_tokens += _total_used_tokens + batched_embeddings += embeddings_batch + + usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens) + return TextEmbeddingResult( + model=model, + embeddings=batched_embeddings, + usage=usage, + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + if len(texts) == 0: + return 0 + total_num_tokens = 0 + for text in texts: + total_num_tokens += self._get_num_tokens_by_gpt2(text) + + return total_num_tokens + + def validate_credentials(self, model: str, credentials: Mapping) -> None: + api_key = credentials['api_key'] + secret_key = credentials['secret_key'] + try: + BaiduAccessToken.get_access_token(api_key, secret_key) + except Exception as e: + raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return invoke_error_mapping() + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=total_tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + 
latency=time.perf_counter() - self.started_at + ) + + return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml index b3a1f60824..6a6b38e6a1 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml @@ -17,6 +17,7 @@ help: en_US: https://cloud.baidu.com/wenxin.html supported_model_types: - llm + - text-embedding configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py new file mode 100644 index 0000000000..0fbd0f55ec --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -0,0 +1,57 @@ +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + ], + InvokeServerUnavailableError: [ + InternalServerError + ], + InvokeRateLimitError: [ + RateLimitReachedError + ], + InvokeAuthorizationError: [ + InvalidAuthenticationError, + InsufficientAccountBalance, + InvalidAPIKeyError, + ], + InvokeBadRequestError: [ + BadRequestError, + KeyError + ] + } + + +class InvalidAuthenticationError(Exception): + pass + +class InvalidAPIKeyError(Exception): + pass + +class RateLimitReachedError(Exception): + pass + +class InsufficientAccountBalance(Exception): + pass + +class InternalServerError(Exception): + pass + +class BadRequestError(Exception): + pass \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py new file mode 100644 index 0000000000..60e8036223 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py @@ -0,0 +1,24 @@ +import os +from time import sleep + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel + + +def test_invoke_embedding_model(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model='embedding-v1', + credentials={ + 'api_key': os.environ.get('WENXIN_API_KEY'), + 'secret_key': os.environ.get('WENXIN_SECRET_KEY') + }, + texts=['hello', '你好', 'xxxxx'], + user="abc-123" + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) \ No newline at end of file diff --git a/api/tests/unit_tests/core/model_runtime/__init__.py b/api/tests/unit_tests/core/model_runtime/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py 
b/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py new file mode 100644 index 0000000000..68334fde82 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py @@ -0,0 +1,75 @@ +import numpy as np + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import ( + TextEmbedding, + WenxinTextEmbeddingModel, +) + + +def test_max_chunks(): + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += len(text) + + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + model = 'embedding-v1' + credentials = { + 'api_key': 'xxxx', + 'secret_key': 'yyyy', + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + max_chunks = embedding_model._get_max_chunks(model, credentials) + embedding_model._create_text_embedding = _create_text_embedding + + texts = ['0123456789' for i in range(0, max_chunks * 2)] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + assert len(result.embeddings) 
== max_chunks * 2 + + +def test_context_size(): + def get_num_tokens_by_gpt2(text: str) -> int: + return GPT2Tokenizer.get_num_tokens(text) + + def mock_text(token_size: int) -> str: + _text = "".join(['0' for i in range(token_size)]) + num_tokens = get_num_tokens_by_gpt2(_text) + ratio = int(np.floor(len(_text) / num_tokens)) + m_text = "".join([_text for i in range(ratio)]) + return m_text + + model = 'embedding-v1' + credentials = { + 'api_key': 'xxxx', + 'secret_key': 'yyyy', + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += get_num_tokens_by_gpt2(text) + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + embedding_model._create_text_embedding = _create_text_embedding + text = mock_text(context_size * 2) + assert get_num_tokens_by_gpt2(text) == context_size * 2 + + texts = [text] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + assert result.usage.tokens == context_size