mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 04:29:07 +08:00
feat(api): support wenxin text embedding (#7377)
This commit is contained in:
parent
a0a67873aa
commit
bfd905602f
195
api/core/model_runtime/model_providers/wenxin/_common.py
Normal file
195
api/core/model_runtime/model_providers/wenxin/_common.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
from requests import post
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||||
|
BadRequestError,
|
||||||
|
InternalServerError,
|
||||||
|
InvalidAPIKeyError,
|
||||||
|
InvalidAuthenticationError,
|
||||||
|
RateLimitReachedError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# module-level cache: api_key -> BaiduAccessToken (guarded by the lock below)
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()


class BaiduAccessToken:
    """A cached OAuth access token for Baidu (Wenxin) APIs."""

    api_key: str
    access_token: str
    expires: datetime

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.access_token = ''
        # Baidu issues tokens valid for ~30 days; we cache for only 3 days so
        # expired entries get evicted regularly (avoids unbounded cache growth).
        self.expires = datetime.now() + timedelta(days=3)

    @staticmethod
    def _get_access_token(api_key: str, secret_key: str) -> str:
        """
        Request a fresh access token from Baidu's OAuth endpoint.

        :param api_key: Baidu client id
        :param secret_key: Baidu client secret
        :return: the raw access-token string
        :raises InvalidAuthenticationError: when the HTTP request itself fails
        :raises InvalidAPIKeyError, InternalServerError, BadRequestError,
            RateLimitReachedError: mapped from Baidu's OAuth error responses
        """
        try:
            response = post(
                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
            )
        except Exception as e:
            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')

        resp = response.json()
        if 'error' in resp:
            if resp['error'] == 'invalid_client':
                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
            elif resp['error'] == 'unknown_error':
                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
            elif resp['error'] == 'invalid_request':
                raise BadRequestError(f'Bad request: {resp["error_description"]}')
            elif resp['error'] == 'rate_limit_exceeded':
                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
            else:
                raise Exception(f'Unknown error: {resp["error_description"]}')

        return resp['access_token']

    @staticmethod
    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
        """
        LLM from Baidu requires access token to invoke the API.
        however, we have api_key and secret_key, and access token is valid for 30 days.
        so we can cache the access token for 3 days. (avoid memory leak)

        it may be more efficient to use a ticker to refresh access token, but it will cause
        more complexity, so we just refresh access tokens when get_access_token is called.
        """
        # look up cache and evict expired tokens; try/finally guarantees the
        # lock is released even if the scan raises
        baidu_access_tokens_lock.acquire()
        try:
            now = datetime.now()
            for key in list(baidu_access_tokens.keys()):
                token = baidu_access_tokens[key]
                if token.expires < now:
                    baidu_access_tokens.pop(key)

            if api_key in baidu_access_tokens:
                # cache hit: return the live token
                return baidu_access_tokens[api_key]

            # cache miss: insert a placeholder so concurrent callers share it
            token = BaiduAccessToken(api_key)
            baidu_access_tokens[api_key] = token
        finally:
            # release before the network call to avoid holding the lock
            # (and to avoid deadlock if the request raises)
            baidu_access_tokens_lock.release()

        try:
            token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
        except Exception:
            # BUG FIX: drop the placeholder on failure — otherwise an entry with
            # an empty access_token would be served from cache for 3 days
            with baidu_access_tokens_lock:
                baidu_access_tokens.pop(api_key, None)
            raise

        token.access_token = token_str
        token.expires = now + timedelta(days=3)
        return token
|
||||||
|
|
||||||
|
|
||||||
|
class _CommonWenxin:
|
||||||
|
api_bases = {
|
||||||
|
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||||
|
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||||
|
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||||
|
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||||
|
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||||
|
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
|
||||||
|
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
|
||||||
|
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||||
|
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
||||||
|
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||||
|
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||||
|
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
||||||
|
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
||||||
|
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
||||||
|
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||||
|
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||||
|
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||||
|
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||||
|
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
||||||
|
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||||
|
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
|
||||||
|
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
|
||||||
|
}
|
||||||
|
|
||||||
|
function_calling_supports = [
|
||||||
|
'ernie-bot',
|
||||||
|
'ernie-bot-8k',
|
||||||
|
'ernie-3.5-8k',
|
||||||
|
'ernie-3.5-8k-0205',
|
||||||
|
'ernie-3.5-8k-1222',
|
||||||
|
'ernie-3.5-4k-0205',
|
||||||
|
'ernie-3.5-128k',
|
||||||
|
'ernie-4.0-8k',
|
||||||
|
'ernie-4.0-turbo-8k',
|
||||||
|
'ernie-4.0-turbo-8k-preview',
|
||||||
|
'yi_34b_chat'
|
||||||
|
]
|
||||||
|
|
||||||
|
api_key: str = ''
|
||||||
|
secret_key: str = ''
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, secret_key: str):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.secret_key = secret_key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_credential_kwargs(credentials: dict) -> dict:
|
||||||
|
credentials_kwargs = {
|
||||||
|
"api_key": credentials['api_key'],
|
||||||
|
"secret_key": credentials['secret_key']
|
||||||
|
}
|
||||||
|
return credentials_kwargs
|
||||||
|
|
||||||
|
def _handle_error(self, code: int, msg: str):
|
||||||
|
error_map = {
|
||||||
|
1: InternalServerError,
|
||||||
|
2: InternalServerError,
|
||||||
|
3: BadRequestError,
|
||||||
|
4: RateLimitReachedError,
|
||||||
|
6: InvalidAuthenticationError,
|
||||||
|
13: InvalidAPIKeyError,
|
||||||
|
14: InvalidAPIKeyError,
|
||||||
|
15: InvalidAPIKeyError,
|
||||||
|
17: RateLimitReachedError,
|
||||||
|
18: RateLimitReachedError,
|
||||||
|
19: RateLimitReachedError,
|
||||||
|
100: InvalidAPIKeyError,
|
||||||
|
111: InvalidAPIKeyError,
|
||||||
|
200: InternalServerError,
|
||||||
|
336000: InternalServerError,
|
||||||
|
336001: BadRequestError,
|
||||||
|
336002: BadRequestError,
|
||||||
|
336003: BadRequestError,
|
||||||
|
336004: InvalidAuthenticationError,
|
||||||
|
336005: InvalidAPIKeyError,
|
||||||
|
336006: BadRequestError,
|
||||||
|
336007: BadRequestError,
|
||||||
|
336008: BadRequestError,
|
||||||
|
336100: InternalServerError,
|
||||||
|
336101: BadRequestError,
|
||||||
|
336102: BadRequestError,
|
||||||
|
336103: BadRequestError,
|
||||||
|
336104: BadRequestError,
|
||||||
|
336105: BadRequestError,
|
||||||
|
336200: InternalServerError,
|
||||||
|
336303: BadRequestError,
|
||||||
|
337006: BadRequestError
|
||||||
|
}
|
||||||
|
|
||||||
|
if code in error_map:
|
||||||
|
raise error_map[code](msg)
|
||||||
|
else:
|
||||||
|
raise InternalServerError(f'Unknown error: {msg}')
|
||||||
|
|
||||||
|
def _get_access_token(self) -> str:
|
||||||
|
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
||||||
|
return token.access_token
|
@ -1,102 +1,17 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from json import dumps, loads
|
from json import dumps, loads
|
||||||
from threading import Lock
|
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from requests import Response, post
|
from requests import Response, post
|
||||||
|
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
|
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
|
||||||
|
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
InternalServerError,
|
InternalServerError,
|
||||||
InvalidAPIKeyError,
|
|
||||||
InvalidAuthenticationError,
|
|
||||||
RateLimitReachedError,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# map api_key to access_token
|
|
||||||
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
|
|
||||||
baidu_access_tokens_lock = Lock()
|
|
||||||
|
|
||||||
class BaiduAccessToken:
|
|
||||||
api_key: str
|
|
||||||
access_token: str
|
|
||||||
expires: datetime
|
|
||||||
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.access_token = ''
|
|
||||||
self.expires = datetime.now() + timedelta(days=3)
|
|
||||||
|
|
||||||
def _get_access_token(api_key: str, secret_key: str) -> str:
|
|
||||||
"""
|
|
||||||
request access token from Baidu
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = post(
|
|
||||||
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
|
|
||||||
headers={
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': 'application/json'
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
|
|
||||||
|
|
||||||
resp = response.json()
|
|
||||||
if 'error' in resp:
|
|
||||||
if resp['error'] == 'invalid_client':
|
|
||||||
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
|
|
||||||
elif resp['error'] == 'unknown_error':
|
|
||||||
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
|
|
||||||
elif resp['error'] == 'invalid_request':
|
|
||||||
raise BadRequestError(f'Bad request: {resp["error_description"]}')
|
|
||||||
elif resp['error'] == 'rate_limit_exceeded':
|
|
||||||
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
|
||||||
else:
|
|
||||||
raise Exception(f'Unknown error: {resp["error_description"]}')
|
|
||||||
|
|
||||||
return resp['access_token']
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
|
|
||||||
"""
|
|
||||||
LLM from Baidu requires access token to invoke the API.
|
|
||||||
however, we have api_key and secret_key, and access token is valid for 30 days.
|
|
||||||
so we can cache the access token for 3 days. (avoid memory leak)
|
|
||||||
|
|
||||||
it may be more efficient to use a ticker to refresh access token, but it will cause
|
|
||||||
more complexity, so we just refresh access tokens when get_access_token is called.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# loop up cache, remove expired access token
|
|
||||||
baidu_access_tokens_lock.acquire()
|
|
||||||
now = datetime.now()
|
|
||||||
for key in list(baidu_access_tokens.keys()):
|
|
||||||
token = baidu_access_tokens[key]
|
|
||||||
if token.expires < now:
|
|
||||||
baidu_access_tokens.pop(key)
|
|
||||||
|
|
||||||
if api_key not in baidu_access_tokens:
|
|
||||||
# if access token not in cache, request it
|
|
||||||
token = BaiduAccessToken(api_key)
|
|
||||||
baidu_access_tokens[api_key] = token
|
|
||||||
# release it to enhance performance
|
|
||||||
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
|
|
||||||
baidu_access_tokens_lock.release()
|
|
||||||
# try to get access token
|
|
||||||
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
|
|
||||||
token.access_token = token_str
|
|
||||||
token.expires = now + timedelta(days=3)
|
|
||||||
return token
|
|
||||||
else:
|
|
||||||
# if access token in cache, return it
|
|
||||||
token = baidu_access_tokens[api_key]
|
|
||||||
baidu_access_tokens_lock.release()
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
class ErnieMessage:
|
class ErnieMessage:
|
||||||
class Role(Enum):
|
class Role(Enum):
|
||||||
@ -120,51 +35,7 @@ class ErnieMessage:
|
|||||||
self.content = content
|
self.content = content
|
||||||
self.role = role
|
self.role = role
|
||||||
|
|
||||||
class ErnieBotModel:
|
class ErnieBotModel(_CommonWenxin):
|
||||||
api_bases = {
|
|
||||||
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
|
||||||
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
|
||||||
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
|
||||||
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
|
||||||
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
|
||||||
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
|
|
||||||
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
|
|
||||||
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
|
||||||
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
|
||||||
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
|
||||||
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
|
||||||
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
|
||||||
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
|
||||||
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
|
||||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
|
||||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
|
||||||
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
|
||||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
|
||||||
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
|
||||||
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
|
||||||
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
|
|
||||||
}
|
|
||||||
|
|
||||||
function_calling_supports = [
|
|
||||||
'ernie-bot',
|
|
||||||
'ernie-bot-8k',
|
|
||||||
'ernie-3.5-8k',
|
|
||||||
'ernie-3.5-8k-0205',
|
|
||||||
'ernie-3.5-8k-1222',
|
|
||||||
'ernie-3.5-4k-0205',
|
|
||||||
'ernie-3.5-128k',
|
|
||||||
'ernie-4.0-8k',
|
|
||||||
'ernie-4.0-turbo-8k',
|
|
||||||
'ernie-4.0-turbo-8k-preview',
|
|
||||||
'yi_34b_chat'
|
|
||||||
]
|
|
||||||
|
|
||||||
api_key: str = ''
|
|
||||||
secret_key: str = ''
|
|
||||||
|
|
||||||
def __init__(self, api_key: str, secret_key: str):
|
|
||||||
self.api_key = api_key
|
|
||||||
self.secret_key = secret_key
|
|
||||||
|
|
||||||
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
|
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
|
||||||
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
|
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
|
||||||
@ -199,51 +70,6 @@ class ErnieBotModel:
|
|||||||
return self._handle_chat_stream_generate_response(resp)
|
return self._handle_chat_stream_generate_response(resp)
|
||||||
return self._handle_chat_generate_response(resp)
|
return self._handle_chat_generate_response(resp)
|
||||||
|
|
||||||
def _handle_error(self, code: int, msg: str):
|
|
||||||
error_map = {
|
|
||||||
1: InternalServerError,
|
|
||||||
2: InternalServerError,
|
|
||||||
3: BadRequestError,
|
|
||||||
4: RateLimitReachedError,
|
|
||||||
6: InvalidAuthenticationError,
|
|
||||||
13: InvalidAPIKeyError,
|
|
||||||
14: InvalidAPIKeyError,
|
|
||||||
15: InvalidAPIKeyError,
|
|
||||||
17: RateLimitReachedError,
|
|
||||||
18: RateLimitReachedError,
|
|
||||||
19: RateLimitReachedError,
|
|
||||||
100: InvalidAPIKeyError,
|
|
||||||
111: InvalidAPIKeyError,
|
|
||||||
200: InternalServerError,
|
|
||||||
336000: InternalServerError,
|
|
||||||
336001: BadRequestError,
|
|
||||||
336002: BadRequestError,
|
|
||||||
336003: BadRequestError,
|
|
||||||
336004: InvalidAuthenticationError,
|
|
||||||
336005: InvalidAPIKeyError,
|
|
||||||
336006: BadRequestError,
|
|
||||||
336007: BadRequestError,
|
|
||||||
336008: BadRequestError,
|
|
||||||
336100: InternalServerError,
|
|
||||||
336101: BadRequestError,
|
|
||||||
336102: BadRequestError,
|
|
||||||
336103: BadRequestError,
|
|
||||||
336104: BadRequestError,
|
|
||||||
336105: BadRequestError,
|
|
||||||
336200: InternalServerError,
|
|
||||||
336303: BadRequestError,
|
|
||||||
337006: BadRequestError
|
|
||||||
}
|
|
||||||
|
|
||||||
if code in error_map:
|
|
||||||
raise error_map[code](msg)
|
|
||||||
else:
|
|
||||||
raise InternalServerError(f'Unknown error: {msg}')
|
|
||||||
|
|
||||||
def _get_access_token(self) -> str:
|
|
||||||
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
|
||||||
return token.access_token
|
|
||||||
|
|
||||||
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
|
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
|
||||||
return [ErnieMessage(message.content, message.role) for message in messages]
|
return [ErnieMessage(message.content, message.role) for message in messages]
|
||||||
|
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
class InvalidAuthenticationError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class InvalidAPIKeyError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class RateLimitReachedError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class InsufficientAccountBalance(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class InternalServerError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class BadRequestError(Exception):
|
|
||||||
pass
|
|
@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
|
||||||
InvokeBadRequestError,
|
|
||||||
InvokeConnectionError,
|
|
||||||
InvokeError,
|
InvokeError,
|
||||||
InvokeRateLimitError,
|
|
||||||
InvokeServerUnavailableError,
|
|
||||||
)
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
|
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
|
||||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
|
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
|
||||||
BadRequestError,
|
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
|
||||||
InsufficientAccountBalance,
|
|
||||||
InternalServerError,
|
|
||||||
InvalidAPIKeyError,
|
|
||||||
InvalidAuthenticationError,
|
|
||||||
RateLimitReachedError,
|
|
||||||
)
|
|
||||||
|
|
||||||
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||||
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
|||||||
api_key = credentials['api_key']
|
api_key = credentials['api_key']
|
||||||
secret_key = credentials['secret_key']
|
secret_key = credentials['secret_key']
|
||||||
try:
|
try:
|
||||||
BaiduAccessToken._get_access_token(api_key, secret_key)
|
BaiduAccessToken.get_access_token(api_key, secret_key)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||||
|
|
||||||
@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
:return: Invoke error mapping
|
:return: Invoke error mapping
|
||||||
"""
|
"""
|
||||||
return {
|
return invoke_error_mapping()
|
||||||
InvokeConnectionError: [
|
|
||||||
],
|
|
||||||
InvokeServerUnavailableError: [
|
|
||||||
InternalServerError
|
|
||||||
],
|
|
||||||
InvokeRateLimitError: [
|
|
||||||
RateLimitReachedError
|
|
||||||
],
|
|
||||||
InvokeAuthorizationError: [
|
|
||||||
InvalidAuthenticationError,
|
|
||||||
InsufficientAccountBalance,
|
|
||||||
InvalidAPIKeyError,
|
|
||||||
],
|
|
||||||
InvokeBadRequestError: [
|
|
||||||
BadRequestError,
|
|
||||||
KeyError
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
@ -0,0 +1,9 @@
|
|||||||
|
# Wenxin (Baidu) embedding-v1 predefined model definition
model: embedding-v1
model_type: text-embedding
model_properties:
  # NOTE(review): assumed to be max input tokens per text — confirm against Baidu docs
  context_size: 384
  # maximum number of texts sent in one embeddings request
  max_chunks: 16
pricing:
  input: '0.0005'
  unit: '0.001'
  currency: RMB
|
@ -0,0 +1,184 @@
|
|||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from json import dumps
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from requests import Response, post
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
|
||||||
|
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||||
|
BadRequestError,
|
||||||
|
InternalServerError,
|
||||||
|
invoke_error_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedding:
    """Minimal interface for a text-embedding backend."""

    @abstractmethod
    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """
        Embed the given texts.

        :param model: model name
        :param texts: texts to embed
        :param user: user id forwarded to the provider
        :return: (embeddings, prompt tokens used, total tokens used)
        """
        # Fixed: the return annotation was the tuple literal
        # `(list[list[float]], int, int)`, which is not a valid type;
        # `tuple[...]` expresses the intended 3-tuple.
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
    """Wenxin embedding client: calls Baidu's embeddings endpoint over HTTP."""

    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """
        Embed *texts* with the given Wenxin embedding model.

        :param model: model name (must be a key of ``api_bases``)
        :param texts: texts to embed; must be non-empty
        :param user: user id forwarded to Baidu
        :return: (embeddings, prompt tokens used, total tokens used)
        :raises InternalServerError: on a non-200 HTTP response
        """
        # Fixed: return annotations below were tuple literals
        # `(list[list[float]], int, int)` (invalid as types); now `tuple[...]`.
        access_token = self._get_access_token()
        url = f'{self.api_bases[model]}?access_token={access_token}'
        body = self._build_embed_request_body(model, texts, user)
        headers = {
            'Content-Type': 'application/json',
        }

        resp = post(url, data=dumps(body), headers=headers)
        if resp.status_code != 200:
            raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
        return self._handle_embed_response(model, resp)

    def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
        """Build the JSON request body; rejects empty input early."""
        if len(texts) == 0:
            raise BadRequestError('The number of texts should not be zero.')
        body = {
            'input': texts,
            'user_id': user,
        }
        return body

    def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
        """Parse a 200 response; API-level error codes are mapped to typed exceptions."""
        data = response.json()
        if 'error_code' in data:
            code = data['error_code']
            msg = data['error_msg']
            # raise a typed error via the shared Wenxin error mapping
            self._handle_error(code, msg)

        embeddings = [v['embedding'] for v in data['data']]
        _usage = data['usage']
        tokens = _usage['prompt_tokens']
        total_tokens = _usage['total_tokens']

        return embeddings, tokens, total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class WenxinTextEmbeddingModel(TextEmbeddingModel):
    """Model-runtime adapter exposing Wenxin text embedding to Dify's TextEmbeddingModel interface."""

    def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
        # factory hook: overridable in tests to inject a fake backend
        return WenxinTextEmbedding(api_key, secret_key)

    def _invoke(self, model: str, credentials: dict, texts: list[str],
                user: Optional[str] = None) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials (requires 'api_key' and 'secret_key')
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """

        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
        # Baidu requires a user id on every request; fall back to a fixed default
        user = user if user else 'ErnieBotDefault'

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)
        inputs = []
        indices = []
        used_tokens = 0
        used_total_tokens = 0

        for i, text in enumerate(texts):

            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                # truncate proportionally by characters so the kept prefix
                # fits the model's context window (approximation)
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)
            indices += [i]

        # batch the inputs so each request carries at most max_chunks texts
        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)
        for i in _iter:
            embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
                model,
                inputs[i: i + max_chunks],
                user)
            used_tokens += _used_tokens
            used_total_tokens += _total_used_tokens
            batched_embeddings += embeddings_batch

        usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
        return TextEmbeddingResult(
            model=model,
            embeddings=batched_embeddings,
            usage=usage,
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: approximate token count (GPT-2 tokenizer based)
        """
        if len(texts) == 0:
            return 0
        total_num_tokens = 0
        for text in texts:
            total_num_tokens += self._get_num_tokens_by_gpt2(text)

        return total_num_tokens

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        # validate by actually requesting an access token; any failure is
        # surfaced as CredentialsValidateFailedError
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        try:
            BaiduAccessToken.get_access_token(api_key, secret_key)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        # shared provider-wide mapping of Wenxin errors to unified invoke errors
        return invoke_error_mapping()

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :param total_tokens: total tokens reported by the API
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=total_tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # NOTE(review): assumes the base class set self.started_at before
            # invocation — confirm against TextEmbeddingModel
            latency=time.perf_counter() - self.started_at
        )

        return usage
|
@ -17,6 +17,7 @@ help:
|
|||||||
en_US: https://cloud.baidu.com/wenxin.html
|
en_US: https://cloud.baidu.com/wenxin.html
|
||||||
supported_model_types:
|
supported_model_types:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
configurate_methods:
|
configurate_methods:
|
||||||
- predefined-model
|
- predefined-model
|
||||||
provider_credential_schema:
|
provider_credential_schema:
|
||||||
|
@ -0,0 +1,57 @@
|
|||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke errors to unified errors.

    The key is the unified error type thrown to the caller; the value lists
    the provider-specific error types that are converted into that unified
    type for the caller.

    :return: Invoke error mapping
    """
    mapping: dict[type[InvokeError], list[type[Exception]]] = {
        InvokeConnectionError: [],
        InvokeServerUnavailableError: [InternalServerError],
        InvokeRateLimitError: [RateLimitReachedError],
        InvokeAuthorizationError: [
            InvalidAuthenticationError,
            InsufficientAccountBalance,
            InvalidAPIKeyError,
        ],
        # KeyError is mapped as well — presumably from missing fields in
        # provider responses; confirm against the wenxin client code.
        InvokeBadRequestError: [BadRequestError, KeyError],
    }
    return mapping
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuthenticationError(Exception):
    """Raised when Baidu rejects the supplied authentication credentials."""
|
||||||
|
|
||||||
|
class InvalidAPIKeyError(Exception):
    """Raised when the provided API key is not valid."""
|
||||||
|
|
||||||
|
class RateLimitReachedError(Exception):
    """Raised when the provider's rate limit has been reached."""
|
||||||
|
|
||||||
|
class InsufficientAccountBalance(Exception):
    """Raised when the account balance is insufficient for the request."""
|
||||||
|
|
||||||
|
class InternalServerError(Exception):
    """Raised when the provider reports an internal server error."""
|
||||||
|
|
||||||
|
class BadRequestError(Exception):
    """Raised when the request sent to the provider is malformed or rejected."""
|
@ -0,0 +1,24 @@
|
|||||||
|
import os
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
|
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_embedding_model():
    """Integration test: invoke the real wenxin embedding endpoint and check shape."""
    # Throttle so back-to-back suite runs don't trip the provider rate limit.
    sleep(3)

    embedding_model = WenxinTextEmbeddingModel()
    credentials = {
        'api_key': os.environ.get('WENXIN_API_KEY'),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY'),
    }

    result = embedding_model.invoke(
        model='embedding-v1',
        credentials=credentials,
        texts=['hello', '你好', 'xxxxx'],
        user="abc-123",
    )

    # One embedding vector per input text.
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 3
    assert isinstance(result.embeddings[0], list)
|
0
api/tests/unit_tests/core/model_runtime/__init__.py
Normal file
0
api/tests/unit_tests/core/model_runtime/__init__.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||||
|
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import (
|
||||||
|
TextEmbedding,
|
||||||
|
WenxinTextEmbeddingModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_chunks():
    """
    The model must split inputs into batches of at most ``_get_max_chunks``
    texts and still return one embedding per input text.

    Fixes: removed the unused ``context_size`` local; replaced manual
    index loops with direct iteration and ``sum``.
    """

    class _MockTextEmbedding(TextEmbedding):
        # Stub the remote embedding call: fixed 3-dim vector per text,
        # token count approximated by character length.
        def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
            embeddings = [[1.0, 2.0, 3.0] for _ in texts]
            tokens = sum(len(text) for text in texts)
            return embeddings, tokens, tokens

    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
        # Bypass the real Baidu client so no credentials/network are needed.
        return _MockTextEmbedding()

    model = 'embedding-v1'
    credentials = {
        'api_key': 'xxxx',
        'secret_key': 'yyyy',
    }
    embedding_model = WenxinTextEmbeddingModel()
    max_chunks = embedding_model._get_max_chunks(model, credentials)
    embedding_model._create_text_embedding = _create_text_embedding

    # Twice max_chunks forces at least two batches internally.
    texts = ['0123456789'] * (max_chunks * 2)
    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
    assert len(result.embeddings) == max_chunks * 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_size():
    """
    Inputs longer than the model context size must be truncated so that
    ``usage.tokens`` never exceeds ``_get_context_size``.

    Fixes: replaced join-over-comprehension string building with string
    multiplication; replaced manual index loops with direct iteration.
    """

    def get_num_tokens_by_gpt2(text: str) -> int:
        return GPT2Tokenizer.get_num_tokens(text)

    def mock_text(token_size: int) -> str:
        # Build a string of '0's whose GPT-2 token count is exactly
        # token_size: measure chars-per-token on a probe string, then scale.
        _text = '0' * token_size
        num_tokens = get_num_tokens_by_gpt2(_text)
        ratio = int(np.floor(len(_text) / num_tokens))
        return _text * ratio

    model = 'embedding-v1'
    credentials = {
        'api_key': 'xxxx',
        'secret_key': 'yyyy',
    }
    embedding_model = WenxinTextEmbeddingModel()
    context_size = embedding_model._get_context_size(model, credentials)

    class _MockTextEmbedding(TextEmbedding):
        # Stub that reports GPT-2 token counts so truncation is exact.
        def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
            embeddings = [[1.0, 2.0, 3.0] for _ in texts]
            tokens = sum(get_num_tokens_by_gpt2(text) for text in texts)
            return embeddings, tokens, tokens

    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
        return _MockTextEmbedding()

    embedding_model._create_text_embedding = _create_text_embedding

    # Sanity-check the fixture before asserting truncation.
    text = mock_text(context_size * 2)
    assert get_num_tokens_by_gpt2(text) == context_size * 2

    texts = [text]
    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
    assert result.usage.tokens == context_size
|
Loading…
x
Reference in New Issue
Block a user