Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-13 07:29:02 +08:00)
chore: refactor the baichuan model (#7953)
This commit is contained in:
parent 14af87527f
commit 0f72a8e89d
@@ -27,11 +27,3 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
-    - variable: secret_key
-      label:
-        en_US: Secret Key
-      type: secret-input
-      required: false
-      placeholder:
-        zh_Hans: 在此输入您的 Secret Key
-        en_US: Enter your Secret Key
@@ -43,3 +43,4 @@ parameter_rules:
       zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
       en_US: Allow the model to perform external search to enhance the generation results.
     required: false
+    deprecated: true
@@ -43,3 +43,4 @@ parameter_rules:
       zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
       en_US: Allow the model to perform external search to enhance the generation results.
     required: false
+    deprecated: true
@@ -4,36 +4,32 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 192000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 128000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 128000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 32000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 32000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
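For illustration only, the defaults declared in these parameter_rules surface to the provider client as a flat parameter dict; the dict below is an assumed example assembled from the YAML defaults above, not code from this commit:

# Assumed example: flat parameters dict built from the YAML defaults above.
parameters = {
    "temperature": 0.3,
    "top_p": 0.85,
    "top_k": 5,
    "max_tokens": 2048,
    "res_format": "json_object",  # one of the declared options: text | json_object
}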
@@ -1,11 +1,10 @@
-from collections.abc import Generator
-from enum import Enum
-from hashlib import md5
-from json import dumps, loads
-from typing import Any, Union
+import json
+from collections.abc import Iterator
+from typing import Any, Optional, Union
 
 from requests import post
 
+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -16,203 +15,133 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
 )
 
 
-class BaichuanMessage:
-    class Role(Enum):
-        USER = 'user'
-        ASSISTANT = 'assistant'
-        # Baichuan does not have system message
-        _SYSTEM = 'system'
-
-    role: str = Role.USER.value
-    content: str
-    usage: dict[str, int] = None
-    stop_reason: str = ''
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            'role': self.role,
-            'content': self.content,
-        }
-
-    def __init__(self, content: str, role: str = 'user') -> None:
-        self.content = content
-        self.role = role
-
 class BaichuanModel:
     api_key: str
-    secret_key: str
 
-    def __init__(self, api_key: str, secret_key: str = '') -> None:
+    def __init__(self, api_key: str) -> None:
         self.api_key = api_key
-        self.secret_key = secret_key
 
-    def _model_mapping(self, model: str) -> str:
+    @property
+    def _model_mapping(self) -> dict:
         return {
-            'baichuan2-turbo': 'Baichuan2-Turbo',
-            'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
-            'baichuan2-53b': 'Baichuan2-53B',
-            'baichuan3-turbo': 'Baichuan3-Turbo',
-            'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
-            'baichuan4': 'Baichuan4',
-        }[model]
+            "baichuan2-turbo": "Baichuan2-Turbo",
+            "baichuan3-turbo": "Baichuan3-Turbo",
+            "baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
+            "baichuan4": "Baichuan4",
+        }
 
-    def _handle_chat_generate_response(self, response) -> BaichuanMessage:
-        resp = response.json()
-        choices = resp.get('choices', [])
-        message = BaichuanMessage(content='', role='assistant')
-        for choice in choices:
-            message.content += choice['message']['content']
-            message.role = choice['message']['role']
-            if choice['finish_reason']:
-                message.stop_reason = choice['finish_reason']
-
-        if 'usage' in resp:
-            message.usage = {
-                'prompt_tokens': resp['usage']['prompt_tokens'],
-                'completion_tokens': resp['usage']['completion_tokens'],
-                'total_tokens': resp['usage']['total_tokens'],
-            }
-
-        return message
-
-    def _handle_chat_stream_generate_response(self, response) -> Generator:
-        for line in response.iter_lines():
-            if not line:
-                continue
-            line = line.decode('utf-8')
-            # remove the first `data: ` prefix
-            if line.startswith('data:'):
-                line = line[5:].strip()
-            try:
-                data = loads(line)
-            except Exception as e:
-                if line.strip() == '[DONE]':
-                    return
-            choices = data.get('choices', [])
-            # save stop reason temporarily
-            stop_reason = ''
-            for choice in choices:
-                if choice.get('finish_reason'):
-                    stop_reason = choice['finish_reason']
-
-                if len(choice['delta']['content']) == 0:
-                    continue
-                yield BaichuanMessage(**choice['delta'])
-
-            # if there is usage, the response is the last one, yield it and return
-            if 'usage' in data:
-                message = BaichuanMessage(content='', role='assistant')
-                message.usage = {
-                    'prompt_tokens': data['usage']['prompt_tokens'],
-                    'completion_tokens': data['usage']['completion_tokens'],
-                    'total_tokens': data['usage']['total_tokens'],
-                }
-                message.stop_reason = stop_reason
-                yield message
-
-    def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
-                          parameters: dict[str, Any]) \
-            -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-            or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            prompt_messages = []
-            for message in messages:
-                if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
-                    # check if the latest message is a user message
-                    if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
-                        prompt_messages[-1]['content'] += message.content
-                    else:
-                        prompt_messages.append({
-                            'content': message.content,
-                            'role': BaichuanMessage.Role.USER.value,
-                        })
-                elif message.role == BaichuanMessage.Role.ASSISTANT.value:
-                    prompt_messages.append({
-                        'content': message.content,
-                        'role': message.role,
-                    })
-            # [baichuan] frequency_penalty must be between 1 and 2
-            if 'frequency_penalty' in parameters:
-                if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
-                    parameters['frequency_penalty'] = 1
+    @property
+    def request_headers(self) -> dict[str, Any]:
+        return {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+        }
+
+    def _build_parameters(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> dict[str, Any]:
+        if model in self._model_mapping.keys():
+            # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
+            # we need to rename it to res_format to get its value
+            if parameters.get("res_format") == "json_object":
+                parameters["response_format"] = {"type": "json_object"}
+
+            if tools or parameters.get("with_search_enhance") is True:
+                parameters["tools"] = []
+
+                # with_search_enhance is deprecated, use web_search instead
+                if parameters.get("with_search_enhance") is True:
+                    parameters["tools"].append(
+                        {
+                            "type": "web_search",
+                            "web_search": {"enable": True},
+                        }
+                    )
+                if tools:
+                    for tool in tools:
+                        parameters["tools"].append(
+                            {
+                                "type": "function",
+                                "function": {
+                                    "name": tool.name,
+                                    "description": tool.description,
+                                    "parameters": tool.parameters,
+                                },
+                            }
+                        )
 
             # turbo api accepts flat parameters
             return {
-                'model': self._model_mapping(model),
-                'stream': stream,
-                'messages': prompt_messages,
+                "model": self._model_mapping.get(model),
+                "stream": stream,
+                "messages": messages,
                 **parameters,
             }
         else:
             raise BadRequestError(f"Unknown model: {model}")
 
-    def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-            or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            # there is no secret key for turbo api
-            return {
-                'Content-Type': 'application/json',
-                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
-                'Authorization': 'Bearer ' + self.api_key,
-            }
-        else:
-            raise BadRequestError(f"Unknown model: {model}")
-
-    def _calculate_md5(self, input_string):
-        return md5(input_string.encode('utf-8')).hexdigest()
-
-    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
-                 parameters: dict[str, Any], timeout: int) \
-            -> Union[Generator, BaichuanMessage]:
-
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-            or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
+    def generate(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        timeout: int,
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> Union[Iterator, dict]:
+
+        if model in self._model_mapping.keys():
+            api_base = "https://api.baichuan-ai.com/v1/chat/completions"
         else:
             raise BadRequestError(f"Unknown model: {model}")
 
-        try:
-            data = self._build_parameters(model, stream, messages, parameters)
-            headers = self._build_headers(model, data)
-        except KeyError:
-            raise InternalServerError(f"Failed to build parameters for model: {model}")
+        data = self._build_parameters(model, stream, messages, parameters, tools)
 
         try:
             response = post(
                 url=api_base,
-                headers=headers,
-                data=dumps(data),
+                headers=self.request_headers,
+                data=json.dumps(data),
                 timeout=timeout,
-                stream=stream
+                stream=stream,
             )
         except Exception as e:
             raise InternalServerError(f"Failed to invoke model: {e}")
 
         if response.status_code != 200:
             try:
                 resp = response.json()
                 # try to parse error message
-                err = resp['error']['code']
-                msg = resp['error']['message']
+                err = resp["error"]["type"]
+                msg = resp["error"]["message"]
             except Exception as e:
-                raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
+                raise InternalServerError(
+                    f"Failed to convert response to json: {e} with text: {response.text}"
+                )
 
-            if err == 'invalid_api_key':
+            if err == "invalid_api_key":
                 raise InvalidAPIKeyError(msg)
-            elif err == 'insufficient_quota':
+            elif err == "insufficient_quota":
                 raise InsufficientAccountBalance(msg)
-            elif err == 'invalid_authentication':
+            elif err == "invalid_authentication":
                 raise InvalidAuthenticationError(msg)
-            elif 'rate' in err:
+            elif err == "invalid_request_error":
+                raise BadRequestError(msg)
+            elif "rate" in err:
                 raise RateLimitReachedError(msg)
-            elif 'internal' in err:
+            elif "internal" in err:
                 raise InternalServerError(msg)
-            elif err == 'api_key_empty':
+            elif err == "api_key_empty":
                 raise InvalidAPIKeyError(msg)
             else:
                 raise InternalServerError(f"Unknown error: {err} with message: {msg}")
 
         if stream:
-            return self._handle_chat_stream_generate_response(response)
+            return response.iter_lines()
         else:
-            return self._handle_chat_generate_response(response)
+            return response.json()
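A minimal usage sketch of the refactored client, assuming a valid API key; the call shape mirrors the ping in validate_credentials, while the key value, model choice and prompt are placeholders:

# Sketch only: non-streaming call against the refactored BaichuanModel.
# "sk-..." and the prompt are placeholders; parameter names come from this commit.
client = BaichuanModel(api_key="sk-...")
result = client.generate(
    model="baichuan3-turbo",
    stream=False,
    messages=[{"role": "user", "content": "ping"}],
    parameters={"max_tokens": 8},
    timeout=60,
)
# With stream=False the method now returns the raw response JSON (a dict).
print(result["choices"][0]["message"]["content"])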
@@ -1,7 +1,12 @@
-from collections.abc import Generator
+import json
+from collections.abc import Generator, Iterator
 from typing import cast
 
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
-from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
+from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -33,19 +38,40 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
 
 
 class BaichuanLarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
-
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: list[PromptMessageTool] | None = None) -> int:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
+        stream: bool = True,
+        user: str | None = None,
+    ) -> LLMResult | Generator:
+        return self._generate(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stream=stream,
+        )
+
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: list[PromptMessageTool] | None = None,
+    ) -> int:
         return self._num_tokens_from_messages(prompt_messages)
 
-    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
+    def _num_tokens_from_messages(
+        self,
+        messages: list[PromptMessage],
+    ) -> int:
         """Calculate num tokens for baichuan model"""
 
         def tokens(text: str):
@@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
             num_tokens += tokens_per_message
             for key, value in message.items():
                 if isinstance(value, list):
-                    text = ''
+                    text = ""
                     for item in value:
-                        if isinstance(item, dict) and item['type'] == 'text':
-                            text += item['text']
+                        if isinstance(item, dict) and item["type"] == "text":
+                            text += item["text"]
 
                     value = text
 
@@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = [tool_call.dict() for tool_call in
+                                              message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "user", "content": message.content}
+            message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
-            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
             message = cast(ToolPromptMessage, message)
             message_dict = {
-                "role": "user",
-                "content": [{
-                    "type": "tool_result",
-                    "tool_use_id": message.tool_call_id,
-                    "content": message.content
-                }]
+                "role": "tool",
+                "content": message.content,
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Unknown message type {type(message)}")
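For reference, a hand-written sketch (not taken from the commit) of the message dicts this conversion now emits for a tool round trip; the tool name and contents are invented, only the dict shapes mirror the code above:

# Invented example values; only the shapes follow the conversion above.
system_msg = {"role": "system", "content": "You are a helpful assistant."}
assistant_msg = {
    "role": "assistant",
    "content": "",
    "tool_calls": [{
        "id": "call_0",
        "type": "function",
        "function": {"name": "get_weather", "arguments": "{\"city\": \"Beijing\"}"},
    }],
}
tool_msg = {"role": "tool", "content": "{\"temperature\": 21}", "tool_call_id": "call_0"}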
@@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel):
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         # ping
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
+        instance = BaichuanModel(api_key=credentials["api_key"])
 
         try:
-            instance.generate(model=model, stream=False, messages=[
-                BaichuanMessage(content='ping', role='user')
-            ], parameters={
-                'max_tokens': 1,
-            }, timeout=60)
+            instance.generate(
+                model=model,
+                stream=False,
+                messages=[{"content": "ping", "role": "user"}],
+                parameters={
+                    "max_tokens": 1,
+                },
+                timeout=60,
+            )
         except Exception as e:
             raise CredentialsValidateFailedError(f"Invalid API key: {e}")
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        if tools is not None and len(tools) > 0:
-            raise InvokeBadRequestError("Baichuan model doesn't support tools")
-
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
-
-        # convert prompt messages to baichuan messages
-        messages = [
-            BaichuanMessage(
-                content=message.content if isinstance(message.content, str) else ''.join([
-                    content.data for content in message.content
-                ]),
-                role=message.role.value
-            ) for message in prompt_messages
-        ]
+    def _generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stream: bool = True,
+    ) -> LLMResult | Generator:
+        instance = BaichuanModel(api_key=credentials["api_key"])
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
 
         # invoke model
-        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
-                                     timeout=60)
+        response = instance.generate(
+            model=model,
+            stream=stream,
+            messages=messages,
+            parameters=model_parameters,
+            timeout=60,
+            tools=tools,
+        )
 
         if stream:
-            return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
+            return self._handle_chat_generate_stream_response(
+                model, prompt_messages, credentials, response
+            )
 
-        return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
+        return self._handle_chat_generate_response(
+            model, prompt_messages, credentials, response
+        )
 
-    def _handle_chat_generate_response(self, model: str,
-                                       prompt_messages: list[PromptMessage],
-                                       credentials: dict,
-                                       response: BaichuanMessage) -> LLMResult:
-        # convert baichuan message to llm result
-        usage = self._calc_response_usage(model=model, credentials=credentials,
-                                          prompt_tokens=response.usage['prompt_tokens'],
-                                          completion_tokens=response.usage['completion_tokens'])
+    def _handle_chat_generate_response(
+        self,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        response: dict,
+    ) -> LLMResult:
+        choices = response.get("choices", [])
+        assistant_message = AssistantPromptMessage(content='', tool_calls=[])
+        if choices and choices[0]["finish_reason"] == "tool_calls":
+            for choice in choices:
+                for tool_call in choice["message"]["tool_calls"]:
+                    tool = AssistantPromptMessage.ToolCall(
+                        id=tool_call.get("id", ""),
+                        type=tool_call.get("type", ""),
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call.get("function", {}).get("name", ""),
+                            arguments=tool_call.get("function", {}).get("arguments", "")
+                        ),
+                    )
+                    assistant_message.tool_calls.append(tool)
+        else:
+            for choice in choices:
+                assistant_message.content += choice["message"]["content"]
+                assistant_message.role = choice["message"]["role"]
+
+        usage = response.get("usage")
+        if usage:
+            # transform usage
+            prompt_tokens = usage["prompt_tokens"]
+            completion_tokens = usage["completion_tokens"]
+        else:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_messages(prompt_messages)
+            completion_tokens = self._num_tokens_from_messages([assistant_message])
+
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=response.content,
-                tool_calls=[]
-            ),
+            message=assistant_message,
             usage=usage,
         )
 
-    def _handle_chat_generate_stream_response(self, model: str,
-                                              prompt_messages: list[PromptMessage],
-                                              credentials: dict,
-                                              response: Generator[BaichuanMessage, None, None]) -> Generator:
-        for message in response:
-            if message.usage:
-                usage = self._calc_response_usage(model=model, credentials=credentials,
-                                                  prompt_tokens=message.usage['prompt_tokens'],
-                                                  completion_tokens=message.usage['completion_tokens'])
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        response: Iterator,
+    ) -> Generator:
+        for line in response:
+            if not line:
+                continue
+            line = line.decode("utf-8")
+            # remove the first `data: ` prefix
+            if line.startswith("data:"):
+                line = line[5:].strip()
+            try:
+                data = json.loads(line)
+            except Exception as e:
+                if line.strip() == "[DONE]":
+                    return
+            choices = data.get("choices", [])
+
+            stop_reason = ""
+            for choice in choices:
+                if choice.get("finish_reason"):
+                    stop_reason = choice["finish_reason"]
+
+                if len(choice["delta"]["content"]) == 0:
+                    continue
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=0,
                         message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
+                            content=choice["delta"]["content"], tool_calls=[]
                         ),
-                        usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=stop_reason,
                     ),
                 )
-            else:
+
+            # if there is usage, the response is the last one, yield it and return
+            if "usage" in data:
+                usage = self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    prompt_tokens=data["usage"]["prompt_tokens"],
+                    completion_tokens=data["usage"]["completion_tokens"],
+                )
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=0,
-                        message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
-                        ),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        message=AssistantPromptMessage(content="", tool_calls=[]),
+                        usage=usage,
+                        finish_reason=stop_reason,
                     ),
                 )
 
@@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-            ],
-            InvokeServerUnavailableError: [
-                InternalServerError
-            ],
-            InvokeRateLimitError: [
-                RateLimitReachedError
-            ],
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [InternalServerError],
+            InvokeRateLimitError: [RateLimitReachedError],
             InvokeAuthorizationError: [
                 InvalidAuthenticationError,
                 InsufficientAccountBalance,
                 InvalidAPIKeyError,
             ],
-            InvokeBadRequestError: [
-                BadRequestError,
-                KeyError
-            ]
+            InvokeBadRequestError: [BadRequestError, KeyError],
         }