chore: refactor the baichuan model (#7953)

非法操作 2024-09-04 16:22:31 +08:00 committed by GitHub
parent 14af87527f
commit 0f72a8e89d
9 changed files with 330 additions and 313 deletions

View File

@@ -27,11 +27,3 @@ provider_credential_schema:
         placeholder:
           zh_Hans: 在此输入您的 API Key
           en_US: Enter your API Key
-      - variable: secret_key
-        label:
-          en_US: Secret Key
-        type: secret-input
-        required: false
-        placeholder:
-          zh_Hans: 在此输入您的 Secret Key
-          en_US: Enter your Secret Key
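
With secret_key gone, every model served by this provider authenticates with a single Bearer token; the unused MD5 helper is also deleted in the Python diff below. A minimal sketch of the resulting auth header, mirroring the request_headers property added in baichuan_turbo.py further down (the key value is a placeholder):

def request_headers(api_key: str) -> dict:
    # All Baichuan endpoints handled here now share one scheme: a JSON body
    # plus a standard Bearer token; no secret_key or MD5 signing is involved.
    return {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }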

View File

@@ -43,3 +43,4 @@ parameter_rules:
       zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
       en_US: Allow the model to perform external search to enhance the generation results.
     required: false
+    deprecated: true
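
The flag survives for backward compatibility but is now marked deprecated: as the refactored _build_parameters in baichuan_turbo.py shows further down, a truthy with_search_enhance is rewritten into a web_search tool entry in the outgoing request. A sketch of that translation:

parameters = {"with_search_enhance": True}  # legacy flag from an existing app
if parameters.get("with_search_enhance") is True:
    # the deprecated flag becomes a web_search tool in the request body
    parameters.setdefault("tools", []).append(
        {"type": "web_search", "web_search": {"enable": True}}
    )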

View File

@@ -43,3 +43,4 @@ parameter_rules:
       zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
       en_US: Allow the model to perform external search to enhance the generation results.
     required: false
+    deprecated: true

View File

@@ -4,36 +4,32 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 192000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
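
Because the turbo API accepts flat parameters (see _build_parameters in the Python diff below), the tightened rules above map one-to-one onto the request body. A hypothetical payload under the new defaults; the mapped model ID and message are illustrative:

payload = {
    "model": "Baichuan2-Turbo",  # illustrative mapped ID
    "stream": False,
    "messages": [{"role": "user", "content": "ping"}],
    "temperature": 0.3,  # new default
    "top_p": 0.85,       # new default
    "top_k": 5,          # now bounded to 0-20
    "max_tokens": 2048,  # replaces required: true / default: 8000
}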

View File

@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 128000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 128000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
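
The new res_format rule is deliberately not named response_format: per the comment in _build_parameters below, LargeLanguageModel._code_block_mode_wrapper() strips a parameter called response_format, so the YAML exposes res_format and the client renames it just before the request. A sketch of that rename:

parameters = {"res_format": "json_object"}  # as selected via the YAML rule above
if parameters.get("res_format") == "json_object":
    parameters["response_format"] = {"type": "json_object"}
# the endpoint then receives an OpenAI-style response_format object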

View File

@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 32000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强

View File

@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 32000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强

View File

@@ -1,11 +1,10 @@
-from collections.abc import Generator
-from enum import Enum
-from hashlib import md5
-from json import dumps, loads
-from typing import Any, Union
+import json
+from collections.abc import Iterator
+from typing import Any, Optional, Union
 
 from requests import post
 
+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -16,174 +15,100 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
 )
 
 
-class BaichuanMessage:
-    class Role(Enum):
-        USER = 'user'
-        ASSISTANT = 'assistant'
-        # Baichuan does not have system message
-        _SYSTEM = 'system'
-
-    role: str = Role.USER.value
-    content: str
-    usage: dict[str, int] = None
-    stop_reason: str = ''
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            'role': self.role,
-            'content': self.content,
-        }
-
-    def __init__(self, content: str, role: str = 'user') -> None:
-        self.content = content
-        self.role = role
-
-
 class BaichuanModel:
     api_key: str
-    secret_key: str
 
-    def __init__(self, api_key: str, secret_key: str = '') -> None:
+    def __init__(self, api_key: str) -> None:
         self.api_key = api_key
-        self.secret_key = secret_key
 
-    def _model_mapping(self, model: str) -> str:
+    @property
+    def _model_mapping(self) -> dict:
         return {
-            'baichuan2-turbo': 'Baichuan2-Turbo',
-            'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
-            'baichuan2-53b': 'Baichuan2-53B',
-            'baichuan3-turbo': 'Baichuan3-Turbo',
-            'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
-            'baichuan4': 'Baichuan4',
-        }[model]
+            "baichuan2-turbo": "Baichuan2-Turbo",
+            "baichuan3-turbo": "Baichuan3-Turbo",
+            "baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
+            "baichuan4": "Baichuan4",
+        }
 
-    def _handle_chat_generate_response(self, response) -> BaichuanMessage:
-        resp = response.json()
-        choices = resp.get('choices', [])
-        message = BaichuanMessage(content='', role='assistant')
-        for choice in choices:
-            message.content += choice['message']['content']
-            message.role = choice['message']['role']
-            if choice['finish_reason']:
-                message.stop_reason = choice['finish_reason']
-
-        if 'usage' in resp:
-            message.usage = {
-                'prompt_tokens': resp['usage']['prompt_tokens'],
-                'completion_tokens': resp['usage']['completion_tokens'],
-                'total_tokens': resp['usage']['total_tokens'],
-            }
-
-        return message
-
-    def _handle_chat_stream_generate_response(self, response) -> Generator:
-        for line in response.iter_lines():
-            if not line:
-                continue
-            line = line.decode('utf-8')
-            # remove the first `data: ` prefix
-            if line.startswith('data:'):
-                line = line[5:].strip()
-            try:
-                data = loads(line)
-            except Exception as e:
-                if line.strip() == '[DONE]':
-                    return
-            choices = data.get('choices', [])
-            # save stop reason temporarily
-            stop_reason = ''
-            for choice in choices:
-                if choice.get('finish_reason'):
-                    stop_reason = choice['finish_reason']
-
-                if len(choice['delta']['content']) == 0:
-                    continue
-                yield BaichuanMessage(**choice['delta'])
-            # if there is usage, the response is the last one, yield it and return
-            if 'usage' in data:
-                message = BaichuanMessage(content='', role='assistant')
-                message.usage = {
-                    'prompt_tokens': data['usage']['prompt_tokens'],
-                    'completion_tokens': data['usage']['completion_tokens'],
-                    'total_tokens': data['usage']['total_tokens'],
-                }
-                message.stop_reason = stop_reason
-                yield message
-
-    def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
-                          parameters: dict[str, Any]) \
-            -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            prompt_messages = []
-            for message in messages:
-                if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
-                    # check if the latest message is a user message
-                    if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
-                        prompt_messages[-1]['content'] += message.content
-                    else:
-                        prompt_messages.append({
-                            'content': message.content,
-                            'role': BaichuanMessage.Role.USER.value,
-                        })
-                elif message.role == BaichuanMessage.Role.ASSISTANT.value:
-                    prompt_messages.append({
-                        'content': message.content,
-                        'role': message.role,
-                    })
-            # [baichuan] frequency_penalty must be between 1 and 2
-            if 'frequency_penalty' in parameters:
-                if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
-                    parameters['frequency_penalty'] = 1
+    @property
+    def request_headers(self) -> dict[str, Any]:
+        return {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+        }
+
+    def _build_parameters(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> dict[str, Any]:
+        if model in self._model_mapping.keys():
+            # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
+            # we need to rename it to res_format to get its value
+            if parameters.get("res_format") == "json_object":
+                parameters["response_format"] = {"type": "json_object"}
+
+            if tools or parameters.get("with_search_enhance") is True:
+                parameters["tools"] = []
+
+                # with_search_enhance is deprecated, use web_search instead
+                if parameters.get("with_search_enhance") is True:
+                    parameters["tools"].append(
+                        {
+                            "type": "web_search",
+                            "web_search": {"enable": True},
+                        }
+                    )
+                if tools:
+                    for tool in tools:
+                        parameters["tools"].append(
+                            {
+                                "type": "function",
+                                "function": {
+                                    "name": tool.name,
+                                    "description": tool.description,
+                                    "parameters": tool.parameters,
+                                },
+                            }
+                        )
+
             # turbo api accepts flat parameters
             return {
-                'model': self._model_mapping(model),
-                'stream': stream,
-                'messages': prompt_messages,
+                "model": self._model_mapping.get(model),
+                "stream": stream,
+                "messages": messages,
                 **parameters,
             }
         else:
             raise BadRequestError(f"Unknown model: {model}")
 
-    def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            # there is no secret key for turbo api
-            return {
-                'Content-Type': 'application/json',
-                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
-                'Authorization': 'Bearer ' + self.api_key,
-            }
+    def generate(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        timeout: int,
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> Union[Iterator, dict]:
+        if model in self._model_mapping.keys():
+            api_base = "https://api.baichuan-ai.com/v1/chat/completions"
         else:
             raise BadRequestError(f"Unknown model: {model}")
 
-    def _calculate_md5(self, input_string):
-        return md5(input_string.encode('utf-8')).hexdigest()
-
-    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
-                 parameters: dict[str, Any], timeout: int) \
-            -> Union[Generator, BaichuanMessage]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
-        else:
-            raise BadRequestError(f"Unknown model: {model}")
-
-        try:
-            data = self._build_parameters(model, stream, messages, parameters)
-            headers = self._build_headers(model, data)
-        except KeyError:
-            raise InternalServerError(f"Failed to build parameters for model: {model}")
+        data = self._build_parameters(model, stream, messages, parameters, tools)
 
         try:
             response = post(
                 url=api_base,
-                headers=headers,
-                data=dumps(data),
+                headers=self.request_headers,
+                data=json.dumps(data),
                 timeout=timeout,
-                stream=stream
+                stream=stream,
             )
         except Exception as e:
             raise InternalServerError(f"Failed to invoke model: {e}")
@@ -192,27 +117,31 @@ class BaichuanModel:
             try:
                 resp = response.json()
                 # try to parse error message
-                err = resp['error']['code']
-                msg = resp['error']['message']
+                err = resp["error"]["type"]
+                msg = resp["error"]["message"]
             except Exception as e:
-                raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
+                raise InternalServerError(
+                    f"Failed to convert response to json: {e} with text: {response.text}"
+                )
 
-            if err == 'invalid_api_key':
+            if err == "invalid_api_key":
                 raise InvalidAPIKeyError(msg)
-            elif err == 'insufficient_quota':
+            elif err == "insufficient_quota":
                 raise InsufficientAccountBalance(msg)
-            elif err == 'invalid_authentication':
+            elif err == "invalid_authentication":
                 raise InvalidAuthenticationError(msg)
-            elif 'rate' in err:
+            elif err == "invalid_request_error":
+                raise BadRequestError(msg)
+            elif "rate" in err:
                 raise RateLimitReachedError(msg)
-            elif 'internal' in err:
+            elif "internal" in err:
                 raise InternalServerError(msg)
-            elif err == 'api_key_empty':
+            elif err == "api_key_empty":
                 raise InvalidAPIKeyError(msg)
             else:
                 raise InternalServerError(f"Unknown error: {err} with message: {msg}")
 
         if stream:
-            return self._handle_chat_stream_generate_response(response)
+            return response.iter_lines()
         else:
-            return self._handle_chat_generate_response(response)
+            return response.json()
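
Taken together, the rewritten client is a thin wrapper over Baichuan's OpenAI-compatible chat endpoint: callers pass plain role/content dicts plus optional tools and get back either the decoded JSON body or the raw iterator of SSE lines. A hypothetical invocation (the weather tool and key are invented for illustration):

from core.model_runtime.entities.message_entities import PromptMessageTool

client = BaichuanModel(api_key="sk-...")  # placeholder key
weather = PromptMessageTool(
    name="get_weather",
    description="Look up the current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}},
)
result = client.generate(
    model="baichuan4",
    stream=False,
    messages=[{"role": "user", "content": "What is the weather in Beijing?"}],
    parameters={"temperature": 0.3},
    timeout=60,
    tools=[weather],
)
# result is the parsed JSON body; with stream=True it would be response.iter_lines()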

View File

@@ -1,7 +1,12 @@
-from collections.abc import Generator
+import json
+from collections.abc import Generator, Iterator
 from typing import cast
 
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
-from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
+from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -33,19 +38,40 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
 
 
 class BaichuanLarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
+        stream: bool = True,
+        user: str | None = None,
+    ) -> LLMResult | Generator:
+        return self._generate(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stream=stream,
+        )
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: list[PromptMessageTool] | None = None) -> int:
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: list[PromptMessageTool] | None = None,
+    ) -> int:
         return self._num_tokens_from_messages(prompt_messages)
 
-    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
+    def _num_tokens_from_messages(
+        self,
+        messages: list[PromptMessage],
+    ) -> int:
         """Calculate num tokens for baichuan model"""
 
         def tokens(text: str):
@@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
             num_tokens += tokens_per_message
             for key, value in message.items():
                 if isinstance(value, list):
-                    text = ''
+                    text = ""
                     for item in value:
-                        if isinstance(item, dict) and item['type'] == 'text':
-                            text += item['text']
+                        if isinstance(item, dict) and item["type"] == "text":
+                            text += item["text"]
 
                     value = text
@@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = [tool_call.dict() for tool_call in
+                                              message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "user", "content": message.content}
+            message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
-            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
             message = cast(ToolPromptMessage, message)
             message_dict = {
-                "role": "user",
-                "content": [{
-                    "type": "tool_result",
-                    "tool_use_id": message.tool_call_id,
-                    "content": message.content
-                }]
+                "role": "tool",
+                "content": message.content,
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Unknown message type {type(message)}")
@@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel):
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         # ping
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
+        instance = BaichuanModel(api_key=credentials["api_key"])
 
         try:
-            instance.generate(model=model, stream=False, messages=[
-                BaichuanMessage(content='ping', role='user')
-            ], parameters={
-                'max_tokens': 1,
-            }, timeout=60)
+            instance.generate(
+                model=model,
+                stream=False,
+                messages=[{"content": "ping", "role": "user"}],
+                parameters={
+                    "max_tokens": 1,
+                },
+                timeout=60,
+            )
         except Exception as e:
             raise CredentialsValidateFailedError(f"Invalid API key: {e}")
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        if tools is not None and len(tools) > 0:
-            raise InvokeBadRequestError("Baichuan model doesn't support tools")
-
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
-
-        # convert prompt messages to baichuan messages
-        messages = [
-            BaichuanMessage(
-                content=message.content if isinstance(message.content, str) else ''.join([
-                    content.data for content in message.content
-                ]),
-                role=message.role.value
-            ) for message in prompt_messages
-        ]
+    def _generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stream: bool = True,
+    ) -> LLMResult | Generator:
+        instance = BaichuanModel(api_key=credentials["api_key"])
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
 
         # invoke model
-        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
-                                     timeout=60)
+        response = instance.generate(
+            model=model,
+            stream=stream,
+            messages=messages,
+            parameters=model_parameters,
+            timeout=60,
+            tools=tools,
+        )
 
         if stream:
-            return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
+            return self._handle_chat_generate_stream_response(
+                model, prompt_messages, credentials, response
+            )
 
-        return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
+        return self._handle_chat_generate_response(
+            model, prompt_messages, credentials, response
+        )
 
-    def _handle_chat_generate_response(self, model: str,
-                                       prompt_messages: list[PromptMessage],
-                                       credentials: dict,
-                                       response: BaichuanMessage) -> LLMResult:
-        # convert baichuan message to llm result
-        usage = self._calc_response_usage(model=model, credentials=credentials,
-                                          prompt_tokens=response.usage['prompt_tokens'],
-                                          completion_tokens=response.usage['completion_tokens'])
+    def _handle_chat_generate_response(
+        self,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        response: dict,
+    ) -> LLMResult:
+        choices = response.get("choices", [])
+        assistant_message = AssistantPromptMessage(content='', tool_calls=[])
+        if choices and choices[0]["finish_reason"] == "tool_calls":
+            for choice in choices:
+                for tool_call in choice["message"]["tool_calls"]:
+                    tool = AssistantPromptMessage.ToolCall(
+                        id=tool_call.get("id", ""),
+                        type=tool_call.get("type", ""),
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call.get("function", {}).get("name", ""),
+                            arguments=tool_call.get("function", {}).get("arguments", "")
+                        ),
+                    )
+                    assistant_message.tool_calls.append(tool)
+        else:
+            for choice in choices:
+                assistant_message.content += choice["message"]["content"]
+                assistant_message.role = choice["message"]["role"]
+
+        usage = response.get("usage")
+        if usage:
+            # transform usage
+            prompt_tokens = usage["prompt_tokens"]
+            completion_tokens = usage["completion_tokens"]
+        else:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_messages(prompt_messages)
+            completion_tokens = self._num_tokens_from_messages([assistant_message])
+
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
 
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=response.content,
-                tool_calls=[]
-            ),
+            message=assistant_message,
             usage=usage,
         )
 
-    def _handle_chat_generate_stream_response(self, model: str,
-                                              prompt_messages: list[PromptMessage],
-                                              credentials: dict,
-                                              response: Generator[BaichuanMessage, None, None]) -> Generator:
-        for message in response:
-            if message.usage:
-                usage = self._calc_response_usage(model=model, credentials=credentials,
-                                                  prompt_tokens=message.usage['prompt_tokens'],
-                                                  completion_tokens=message.usage['completion_tokens'])
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=0,
-                        message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
-                        ),
-                        usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
-                    ),
-                )
-            else:
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=0,
-                        message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
-                        ),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
-                    ),
-                )
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        response: Iterator,
+    ) -> Generator:
+        for line in response:
+            if not line:
+                continue
+            line = line.decode("utf-8")
+            # remove the first `data: ` prefix
+            if line.startswith("data:"):
+                line = line[5:].strip()
+            try:
+                data = json.loads(line)
+            except Exception as e:
+                if line.strip() == "[DONE]":
+                    return
+            choices = data.get("choices", [])
+
+            stop_reason = ""
+            for choice in choices:
+                if choice.get("finish_reason"):
+                    stop_reason = choice["finish_reason"]
+
+                if len(choice["delta"]["content"]) == 0:
+                    continue
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(
+                            content=choice["delta"]["content"], tool_calls=[]
+                        ),
+                        finish_reason=stop_reason,
+                    ),
+                )
+
+            # if there is usage, the response is the last one, yield it and return
+            if "usage" in data:
+                usage = self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    prompt_tokens=data["usage"]["prompt_tokens"],
+                    completion_tokens=data["usage"]["completion_tokens"],
+                )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(content="", tool_calls=[]),
+                        usage=usage,
+                        finish_reason=stop_reason,
+                    ),
+                )
@@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-            ],
-            InvokeServerUnavailableError: [
-                InternalServerError
-            ],
-            InvokeRateLimitError: [
-                RateLimitReachedError
-            ],
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [InternalServerError],
+            InvokeRateLimitError: [RateLimitReachedError],
             InvokeAuthorizationError: [
                 InvalidAuthenticationError,
                 InsufficientAccountBalance,
                 InvalidAPIKeyError,
             ],
-            InvokeBadRequestError: [
-                BadRequestError,
-                KeyError
-            ]
+            InvokeBadRequestError: [BadRequestError, KeyError],
         }
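
One consequence of the message conversion and response handling above: tool traffic now follows the OpenAI convention end to end. Outbound, a ToolPromptMessage serializes to a role "tool" dict; inbound, a choice finishing with "tool_calls" is unpacked into AssistantPromptMessage.ToolCall objects instead of text. Illustrative shapes only; the ids and arguments below are made up:

# outbound tool result, as produced by _convert_prompt_message_to_dict
tool_result = {
    "role": "tool",
    "content": '{"temp_c": 21}',
    "tool_call_id": "call_123",
}

# inbound choice that _handle_chat_generate_response maps to ToolCall objects
choice = {
    "finish_reason": "tool_calls",
    "message": {
        "role": "assistant",
        "tool_calls": [{
            "id": "call_123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
        }],
    },
}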