Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-11 22:38:59 +08:00)
chore: refactor the baichuan model (#7953)
This commit is contained in:
parent 14af87527f · commit 0f72a8e89d
@@ -27,11 +27,3 @@ provider_credential_schema:
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: secret_key
      label:
        en_US: Secret Key
      type: secret-input
      required: false
      placeholder:
        zh_Hans: 在此输入您的 Secret Key
        en_US: Enter your Secret Key

@@ -43,3 +43,4 @@ parameter_rules:
      zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
      en_US: Allow the model to perform external search to enhance the generation results.
    required: false
    deprecated: true

@@ -43,3 +43,4 @@ parameter_rules:
      zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
      en_US: Allow the model to perform external search to enhance the generation results.
    required: false
    deprecated: true

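The two hunks above only mark `with_search_enhance` as deprecated; the refactored client (baichuan_turbo.py, later in this diff) translates the flag into a `web_search` tool entry instead. A minimal sketch of that translation, mirroring the branch added to `_build_parameters` (the helper name is invented for illustration):

# Sketch: rewrite the deprecated with_search_enhance flag as a web_search tool,
# following the logic added to BaichuanModel._build_parameters in this commit.
def apply_search_enhance(parameters: dict) -> dict:
    if parameters.pop("with_search_enhance", False) is True:
        parameters.setdefault("tools", []).append(
            {"type": "web_search", "web_search": {"enable": True}}
        )
    return parameters
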
@@ -4,36 +4,32 @@ label:
model_type: llm
features:
  - agent-thought
  - multi-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.3
  - name: top_p
    use_template: top_p
    default: 0.85
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    min: 0
    max: 20
    default: 5
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8000
    min: 1
    max: 192000
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
    default: 1
    min: 1
    max: 2
    default: 2048
  - name: with_search_enhance
    label:
      zh_Hans: 搜索增强

@@ -4,36 +4,44 @@ label:
model_type: llm
features:
  - agent-thought
  - multi-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.3
  - name: top_p
    use_template: top_p
    default: 0.85
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    min: 0
    max: 20
    default: 5
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8000
    min: 1
    max: 128000
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
    default: 1
    min: 1
    max: 2
    default: 2048
  - name: res_format
    label:
      zh_Hans: 回复格式
      en_US: response format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
  - name: with_search_enhance
    label:
      zh_Hans: 搜索增强

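The new `res_format` rule above is consumed by `BaichuanModel._build_parameters` (see the Python hunks below), which renames it to the API-level `response_format` field because the code-block-mode wrapper strips a parameter literally named `response_format`. A minimal sketch of that mapping (the helper name is invented for illustration):

# Sketch: map the YAML-level res_format option onto the request payload's
# response_format field, as _build_parameters does for json_object output.
def apply_response_format(parameters: dict) -> dict:
    if parameters.get("res_format") == "json_object":
        parameters["response_format"] = {"type": "json_object"}
    return parameters
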
@@ -4,36 +4,44 @@ label:
model_type: llm
features:
  - agent-thought
  - multi-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.3
  - name: top_p
    use_template: top_p
    default: 0.85
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    min: 0
    max: 20
    default: 5
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8000
    min: 1
    max: 32000
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
    default: 1
    min: 1
    max: 2
    default: 2048
  - name: res_format
    label:
      zh_Hans: 回复格式
      en_US: response format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
  - name: with_search_enhance
    label:
      zh_Hans: 搜索增强

@@ -4,36 +4,44 @@ label:
model_type: llm
features:
  - agent-thought
  - multi-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.3
  - name: top_p
    use_template: top_p
    default: 0.85
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    min: 0
    max: 20
    default: 5
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8000
    min: 1
    max: 32000
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
    default: 1
    min: 1
    max: 2
    default: 2048
  - name: res_format
    label:
      zh_Hans: 回复格式
      en_US: response format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
  - name: with_search_enhance
    label:
      zh_Hans: 搜索增强

@@ -1,11 +1,10 @@
from collections.abc import Generator
from enum import Enum
from hashlib import md5
from json import dumps, loads
from typing import Any, Union
import json
from collections.abc import Iterator
from typing import Any, Optional, Union

from requests import post

from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
    BadRequestError,
    InsufficientAccountBalance,
@@ -16,203 +15,133 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
)


class BaichuanMessage:
    class Role(Enum):
        USER = 'user'
        ASSISTANT = 'assistant'
        # Baichuan does not have system message
        _SYSTEM = 'system'

    role: str = Role.USER.value
    content: str
    usage: dict[str, int] = None
    stop_reason: str = ''

    def to_dict(self) -> dict[str, Any]:
        return {
            'role': self.role,
            'content': self.content,
        }

    def __init__(self, content: str, role: str = 'user') -> None:
        self.content = content
        self.role = role

class BaichuanModel:
    api_key: str
    secret_key: str

    def __init__(self, api_key: str, secret_key: str = '') -> None:
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.secret_key = secret_key

    def _model_mapping(self, model: str) -> str:
    @property
    def _model_mapping(self) -> dict:
        return {
            'baichuan2-turbo': 'Baichuan2-Turbo',
            'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
            'baichuan2-53b': 'Baichuan2-53B',
            'baichuan3-turbo': 'Baichuan3-Turbo',
            'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
            'baichuan4': 'Baichuan4',
        }[model]
            "baichuan2-turbo": "Baichuan2-Turbo",
            "baichuan3-turbo": "Baichuan3-Turbo",
            "baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
            "baichuan4": "Baichuan4",
        }

    def _handle_chat_generate_response(self, response) -> BaichuanMessage:
        resp = response.json()
        choices = resp.get('choices', [])
        message = BaichuanMessage(content='', role='assistant')
        for choice in choices:
            message.content += choice['message']['content']
            message.role = choice['message']['role']
            if choice['finish_reason']:
                message.stop_reason = choice['finish_reason']
    @property
    def request_headers(self) -> dict[str, Any]:
        return {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
        }

        if 'usage' in resp:
            message.usage = {
                'prompt_tokens': resp['usage']['prompt_tokens'],
                'completion_tokens': resp['usage']['completion_tokens'],
                'total_tokens': resp['usage']['total_tokens'],
            }
    def _build_parameters(
        self,
        model: str,
        stream: bool,
        messages: list[dict],
        parameters: dict[str, Any],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> dict[str, Any]:
        if model in self._model_mapping.keys():
            # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
            # we need to rename it to res_format to get its value
            if parameters.get("res_format") == "json_object":
                parameters["response_format"] = {"type": "json_object"}

        return message

    def _handle_chat_stream_generate_response(self, response) -> Generator:
        for line in response.iter_lines():
            if not line:
                continue
            line = line.decode('utf-8')
            # remove the first `data: ` prefix
            if line.startswith('data:'):
                line = line[5:].strip()
            try:
                data = loads(line)
            except Exception as e:
                if line.strip() == '[DONE]':
                    return
            choices = data.get('choices', [])
            # save stop reason temporarily
            stop_reason = ''
            for choice in choices:
                if choice.get('finish_reason'):
                    stop_reason = choice['finish_reason']
            if tools or parameters.get("with_search_enhance") is True:
                parameters["tools"] = []

                if len(choice['delta']['content']) == 0:
                    continue
                yield BaichuanMessage(**choice['delta'])

            # if there is usage, the response is the last one, yield it and return
            if 'usage' in data:
                message = BaichuanMessage(content='', role='assistant')
                message.usage = {
                    'prompt_tokens': data['usage']['prompt_tokens'],
                    'completion_tokens': data['usage']['completion_tokens'],
                    'total_tokens': data['usage']['total_tokens'],
                }
                message.stop_reason = stop_reason
                yield message

    def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
                          parameters: dict[str, Any]) \
            -> dict[str, Any]:
        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
            prompt_messages = []
            for message in messages:
                if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
                    # check if the latest message is a user message
                    if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
                        prompt_messages[-1]['content'] += message.content
                    else:
                        prompt_messages.append({
                            'content': message.content,
                            'role': BaichuanMessage.Role.USER.value,
                        })
                elif message.role == BaichuanMessage.Role.ASSISTANT.value:
                    prompt_messages.append({
                        'content': message.content,
                        'role': message.role,
                    })
            # [baichuan] frequency_penalty must be between 1 and 2
            if 'frequency_penalty' in parameters:
                if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
                    parameters['frequency_penalty'] = 1
            # with_search_enhance is deprecated, use web_search instead
            if parameters.get("with_search_enhance") is True:
                parameters["tools"].append(
                    {
                        "type": "web_search",
                        "web_search": {"enable": True},
                    }
                )
            if tools:
                for tool in tools:
                    parameters["tools"].append(
                        {
                            "type": "function",
                            "function": {
                                "name": tool.name,
                                "description": tool.description,
                                "parameters": tool.parameters,
                            },
                        }
                    )

            # turbo api accepts flat parameters
            return {
                'model': self._model_mapping(model),
                'stream': stream,
                'messages': prompt_messages,
                "model": self._model_mapping.get(model),
                "stream": stream,
                "messages": messages,
                **parameters,
            }
        else:
            raise BadRequestError(f"Unknown model: {model}")

    def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
            # there is no secret key for turbo api
            return {
                'Content-Type': 'application/json',
                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
                'Authorization': 'Bearer ' + self.api_key,
            }
        else:
            raise BadRequestError(f"Unknown model: {model}")

    def _calculate_md5(self, input_string):
        return md5(input_string.encode('utf-8')).hexdigest()

    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
                 parameters: dict[str, Any], timeout: int) \
            -> Union[Generator, BaichuanMessage]:

        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
            api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
    def generate(
        self,
        model: str,
        stream: bool,
        messages: list[dict],
        parameters: dict[str, Any],
        timeout: int,
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> Union[Iterator, dict]:

        if model in self._model_mapping.keys():
            api_base = "https://api.baichuan-ai.com/v1/chat/completions"
        else:
            raise BadRequestError(f"Unknown model: {model}")

        try:
            data = self._build_parameters(model, stream, messages, parameters)
            headers = self._build_headers(model, data)
        except KeyError:
            raise InternalServerError(f"Failed to build parameters for model: {model}")

        data = self._build_parameters(model, stream, messages, parameters, tools)

        try:
            response = post(
                url=api_base,
                headers=headers,
                data=dumps(data),
                headers=self.request_headers,
                data=json.dumps(data),
                timeout=timeout,
                stream=stream
                stream=stream,
            )
        except Exception as e:
            raise InternalServerError(f"Failed to invoke model: {e}")


        if response.status_code != 200:
            try:
                resp = response.json()
                # try to parse error message
                err = resp['error']['code']
                msg = resp['error']['message']
                err = resp["error"]["type"]
                msg = resp["error"]["message"]
            except Exception as e:
                raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
                raise InternalServerError(
                    f"Failed to convert response to json: {e} with text: {response.text}"
                )

            if err == 'invalid_api_key':
            if err == "invalid_api_key":
                raise InvalidAPIKeyError(msg)
            elif err == 'insufficient_quota':
            elif err == "insufficient_quota":
                raise InsufficientAccountBalance(msg)
            elif err == 'invalid_authentication':
            elif err == "invalid_authentication":
                raise InvalidAuthenticationError(msg)
            elif 'rate' in err:
            elif err == "invalid_request_error":
                raise BadRequestError(msg)
            elif "rate" in err:
                raise RateLimitReachedError(msg)
            elif 'internal' in err:
            elif "internal" in err:
                raise InternalServerError(msg)
            elif err == 'api_key_empty':
            elif err == "api_key_empty":
                raise InvalidAPIKeyError(msg)
            else:
                raise InternalServerError(f"Unknown error: {err} with message: {msg}")


        if stream:
            return self._handle_chat_stream_generate_response(response)
            return response.iter_lines()
        else:
            return self._handle_chat_generate_response(response)
            return response.json()

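Taken together, the refactored client drops the secret_key/headers/MD5 plumbing and exposes a single dict-in/dict-out `generate` call. A usage sketch, assuming a valid API key (the model name and prompt are illustrative):

# Sketch: non-streaming call against the refactored BaichuanModel.
client = BaichuanModel(api_key="YOUR_API_KEY")  # secret_key is no longer needed
result = client.generate(
    model="baichuan3-turbo",
    stream=False,
    messages=[{"role": "user", "content": "ping"}],
    parameters={"temperature": 0.3, "max_tokens": 8},
    timeout=60,
    tools=None,
)
# result is the parsed JSON body (response.json()),
# e.g. result["choices"][0]["message"]["content"]
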
@@ -1,7 +1,12 @@
from collections.abc import Generator
import json
from collections.abc import Generator, Iterator
from typing import cast

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
@@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
    BadRequestError,
    InsufficientAccountBalance,
@@ -33,19 +38,40 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor


class BaichuanLarguageModel(LargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ) -> LLMResult | Generator:
        return self._generate(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stream=stream,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: list[PromptMessageTool] | None = None,
    ) -> int:
        return self._num_tokens_from_messages(prompt_messages)

    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
    def _num_tokens_from_messages(
        self,
        messages: list[PromptMessage],
    ) -> int:
        """Calculate num tokens for baichuan model"""

        def tokens(text: str):
@@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
            num_tokens += tokens_per_message
            for key, value in message.items():
                if isinstance(value, list):
                    text = ''
                    text = ""
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']
                        if isinstance(item, dict) and item["type"] == "text":
                            text += item["text"]

                    value = text

@@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                message_dict["tool_calls"] = [tool_call.dict() for tool_call in
                                              message.tool_calls]
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "user", "content": message.content}
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
            message = cast(ToolPromptMessage, message)
            message_dict = {
                "role": "user",
                "content": [{
                    "type": "tool_result",
                    "tool_use_id": message.tool_call_id,
                    "content": message.content
                }]
                "role": "tool",
                "content": message.content,
                "tool_call_id": message.tool_call_id
            }
        else:
            raise ValueError(f"Unknown message type {type(message)}")
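After this change the converter emits OpenAI-style role dicts, including a plain `tool` role instead of the Anthropic-style `tool_result` block. A sketch of the dict shapes it now produces (contents are illustrative, not taken from the diff):

# Sketch: example outputs of the prompt-message-to-dict conversion above.
converted = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 6 * 7?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {"id": "call_123", "type": "function",
             "function": {"name": "calc", "arguments": "{\"expr\": \"6 * 7\"}"}}
        ],
    },
    {"role": "tool", "content": "42", "tool_call_id": "call_123"},
]
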
@@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel):

    def validate_credentials(self, model: str, credentials: dict) -> None:
        # ping
        instance = BaichuanModel(
            api_key=credentials['api_key'],
            secret_key=credentials.get('secret_key', '')
        )
        instance = BaichuanModel(api_key=credentials["api_key"])

        try:
            instance.generate(model=model, stream=False, messages=[
                BaichuanMessage(content='ping', role='user')
            ], parameters={
                'max_tokens': 1,
            }, timeout=60)
            instance.generate(
                model=model,
                stream=False,
                messages=[{"content": "ping", "role": "user"}],
                parameters={
                    "max_tokens": 1,
                },
                timeout=60,
            )
        except Exception as e:
            raise CredentialsValidateFailedError(f"Invalid API key: {e}")

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        if tools is not None and len(tools) > 0:
            raise InvokeBadRequestError("Baichuan model doesn't support tools")
    def _generate(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stream: bool = True,
    ) -> LLMResult | Generator:

        instance = BaichuanModel(
            api_key=credentials['api_key'],
            secret_key=credentials.get('secret_key', '')
        )

        # convert prompt messages to baichuan messages
        messages = [
            BaichuanMessage(
                content=message.content if isinstance(message.content, str) else ''.join([
                    content.data for content in message.content
                ]),
                role=message.role.value
            ) for message in prompt_messages
        ]
        instance = BaichuanModel(api_key=credentials["api_key"])
        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]

        # invoke model
        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
                                     timeout=60)
        response = instance.generate(
            model=model,
            stream=stream,
            messages=messages,
            parameters=model_parameters,
            timeout=60,
            tools=tools,
        )

        if stream:
            return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
            return self._handle_chat_generate_stream_response(
                model, prompt_messages, credentials, response
            )

        return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
        return self._handle_chat_generate_response(
            model, prompt_messages, credentials, response
        )

    def _handle_chat_generate_response(
        self,
        model: str,
        prompt_messages: list[PromptMessage],
        credentials: dict,
        response: dict,
    ) -> LLMResult:
        choices = response.get("choices", [])
        assistant_message = AssistantPromptMessage(content='', tool_calls=[])
        if choices and choices[0]["finish_reason"] == "tool_calls":
            for choice in choices:
                for tool_call in choice["message"]["tool_calls"]:
                    tool = AssistantPromptMessage.ToolCall(
                        id=tool_call.get("id", ""),
                        type=tool_call.get("type", ""),
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call.get("function", {}).get("name", ""),
                            arguments=tool_call.get("function", {}).get("arguments", "")
                        ),
                    )
                    assistant_message.tool_calls.append(tool)
        else:
            for choice in choices:
                assistant_message.content += choice["message"]["content"]
                assistant_message.role = choice["message"]["role"]

        usage = response.get("usage")
        if usage:
            # transform usage
            prompt_tokens = usage["prompt_tokens"]
            completion_tokens = usage["completion_tokens"]
        else:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_messages(prompt_messages)
            completion_tokens = self._num_tokens_from_messages([assistant_message])

        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )

    def _handle_chat_generate_response(self, model: str,
                                       prompt_messages: list[PromptMessage],
                                       credentials: dict,
                                       response: BaichuanMessage) -> LLMResult:
        # convert baichuan message to llm result
        usage = self._calc_response_usage(model=model, credentials=credentials,
                                          prompt_tokens=response.usage['prompt_tokens'],
                                          completion_tokens=response.usage['completion_tokens'])
        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=response.content,
                tool_calls=[]
            ),
            message=assistant_message,
            usage=usage,
        )

    def _handle_chat_generate_stream_response(self, model: str,
                                              prompt_messages: list[PromptMessage],
                                              credentials: dict,
                                              response: Generator[BaichuanMessage, None, None]) -> Generator:
        for message in response:
            if message.usage:
                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                  prompt_tokens=message.usage['prompt_tokens'],
                                                  completion_tokens=message.usage['completion_tokens'])
    def _handle_chat_generate_stream_response(
        self,
        model: str,
        prompt_messages: list[PromptMessage],
        credentials: dict,
        response: Iterator,
    ) -> Generator:
        for line in response:
            if not line:
                continue
            line = line.decode("utf-8")
            # remove the first `data: ` prefix
            if line.startswith("data:"):
                line = line[5:].strip()
            try:
                data = json.loads(line)
            except Exception as e:
                if line.strip() == "[DONE]":
                    return
            choices = data.get("choices", [])

            stop_reason = ""
            for choice in choices:
                if choice.get("finish_reason"):
                    stop_reason = choice["finish_reason"]

                if len(choice["delta"]["content"]) == 0:
                    continue
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=message.content,
                            tool_calls=[]
                            content=choice["delta"]["content"], tool_calls=[]
                        ),
                        usage=usage,
                        finish_reason=message.stop_reason if message.stop_reason else None,
                        finish_reason=stop_reason,
                    ),
                )
            else:

            # if there is usage, the response is the last one, yield it and return
            if "usage" in data:
                usage = self._calc_response_usage(
                    model=model,
                    credentials=credentials,
                    prompt_tokens=data["usage"]["prompt_tokens"],
                    completion_tokens=data["usage"]["completion_tokens"],
                )
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=message.content,
                            tool_calls=[]
                        ),
                        finish_reason=message.stop_reason if message.stop_reason else None,
                        message=AssistantPromptMessage(content="", tool_calls=[]),
                        usage=usage,
                        finish_reason=stop_reason,
                    ),
                )

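The stream handler above consumes the raw `iter_lines()` output returned by `BaichuanModel.generate`: each event is a `data: `-prefixed JSON chunk, the final chunk carries `usage`, and the stream ends with a `[DONE]` sentinel. A sketch of the wire format it is written against (payload values invented for illustration):

# Sketch: SSE lines as parsed by _handle_chat_generate_stream_response.
example_stream = [
    b'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}',
    b'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}],'
    b' "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}}',
    b'data: [DONE]',
]
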
@@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
            ],
            InvokeServerUnavailableError: [
                InternalServerError
            ],
            InvokeRateLimitError: [
                RateLimitReachedError
            ],
            InvokeConnectionError: [],
            InvokeServerUnavailableError: [InternalServerError],
            InvokeRateLimitError: [RateLimitReachedError],
            InvokeAuthorizationError: [
                InvalidAuthenticationError,
                InsufficientAccountBalance,
                InvalidAPIKeyError,
            ],
            InvokeBadRequestError: [
                BadRequestError,
                KeyError
            ]
            InvokeBadRequestError: [BadRequestError, KeyError],
        }
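This mapping is how the runtime folds provider-specific exceptions into the unified invoke errors. A sketch of how such a table can be applied (the helper is an assumption for illustration, not part of the commit):

# Sketch: pick the unified invoke-error class that covers a raised provider error.
def classify_error(mapping: dict, error: Exception) -> type | None:
    for invoke_error, provider_errors in mapping.items():
        if any(isinstance(error, cls) for cls in provider_errors):
            return invoke_error
    return None
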