feat: support doubao llm function calling (#5100)

Author: sino · 2024-06-12 15:43:50 +08:00 · committed by GitHub
Parent: 25b0a97851 · Commit: 0ce97e6315
3 changed files with 91 additions and 16 deletions


@@ -7,7 +7,9 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
@@ -36,10 +38,11 @@ class MaaSClient(MaasService):
         client.set_sk(sk)
         return client
 
-    def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
+    def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
         req = {
             'parameters': params,
-            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
+            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
+            **extra_model_kwargs,
         }
         if not stream:
             return super().chat(
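
For orientation, a sketch of the request body `chat()` now assembles when a caller forwards tools through `**extra_model_kwargs` (key names follow the diff; all values here are made up):

```python
# Illustrative only: 'parameters' and 'messages' come from the existing
# code path; 'tools' is merged in by **extra_model_kwargs.
req = {
    'parameters': {'max_new_tokens': 1024, 'temperature': 0.7},
    'messages': [{'role': 'user', 'content': 'What is the weather in Beijing?'}],
    'tools': [{
        'type': 'function',
        'function': {
            'name': 'get_weather',
            'description': 'Get the current weather for a city',
            'parameters': {'type': 'object', 'properties': {'location': {'type': 'string'}}},
        },
    }],
}
```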
@@ -89,10 +92,22 @@ class MaaSClient(MaasService):
             message = cast(AssistantPromptMessage, message)
             message_dict = {'role': ChatRole.ASSISTANT,
                             'content': message.content}
+            if message.tool_calls:
+                message_dict['tool_calls'] = [
+                    {
+                        'name': call.function.name,
+                        'arguments': call.function.arguments
+                    } for call in message.tool_calls
+                ]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {'role': ChatRole.SYSTEM,
                             'content': message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {'role': ChatRole.FUNCTION,
+                            'content': message.content,
+                            'name': message.tool_call_id}
         else:
             raise ValueError(f"Got unknown PromptMessage type {message}")
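
With these branches, a full tool-call round trip can be represented. A minimal sketch, assuming the entity classes accept the fields exactly as the diff uses them (example values are hypothetical):

```python
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ToolPromptMessage,
)

# An assistant turn that requested a tool call...
assistant = AssistantPromptMessage(
    content='',
    tool_calls=[AssistantPromptMessage.ToolCall(
        id='get_weather',
        type='function',
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
            name='get_weather',
            arguments='{"location": "Beijing"}',
        ),
    )],
)
# ...converts to a MaaS dict that keeps only name/arguments per call:
# {'role': ChatRole.ASSISTANT, 'content': '',
#  'tool_calls': [{'name': 'get_weather', 'arguments': '{"location": "Beijing"}'}]}

# A tool result goes back under ChatRole.FUNCTION; note that the MaaS
# 'name' field carries the tool_call_id, not the tool's display name.
tool_result = ToolPromptMessage(content='{"temp_c": 21}', tool_call_id='get_weather')
# -> {'role': ChatRole.FUNCTION, 'content': '{"temp_c": 21}', 'name': 'get_weather'}
```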
@@ -106,3 +121,14 @@ class MaaSClient(MaasService):
             raise wrap_error(e)
         return resp
+
+    @staticmethod
+    def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            }
+        }
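
This helper maps Dify's `PromptMessageTool` onto the OpenAI-style function schema MaaS expects. A quick sketch with a hypothetical tool:

```python
from core.model_runtime.entities.message_entities import PromptMessageTool

tool = PromptMessageTool(
    name='get_weather',
    description='Get the current weather for a city',
    parameters={
        'type': 'object',
        'properties': {'location': {'type': 'string'}},
        'required': ['location'],
    },
)

MaaSClient.transform_tool_prompt_to_maas_config(tool)
# -> {'type': 'function',
#     'function': {'name': 'get_weather',
#                  'description': 'Get the current weather for a city',
#                  'parameters': {'type': 'object', ...}}}
```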


@@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         if stop:
             req_params['stop'] = stop
 
+        extra_model_kwargs = {}
+        if tools:
+            extra_model_kwargs['tools'] = [
+                MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
+            ]
+
         resp = MaaSClient.wrap_exception(
-            lambda: client.chat(req_params, prompt_messages, stream))
+            lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
         if not stream:
             return self._handle_chat_response(model, credentials, prompt_messages, resp)
         return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
@@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         choice = choices[0]
         message = choice['message']
 
+        # parse tool calls
+        tool_calls = []
+        if message['tool_calls']:
+            for call in message['tool_calls']:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=call['function']['name'],
+                    type=call['type'],
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=call['function']['name'],
+                        arguments=call['function']['arguments']
+                    )
+                )
+                tool_calls.append(tool_call)
+
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=AssistantPromptMessage(
                 content=message['content'] if message['content'] else '',
-                tool_calls=[],
+                tool_calls=tool_calls,
             ),
             usage=self._calc_usage(model, credentials, resp['usage']),
         )
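
To make the parsing concrete, here is a hypothetical non-streaming response fragment and what the loop above builds from it. Two details worth noting: the MaaS payload carries no separate call id, so the function name is reused as the `ToolCall` id; and `message['tool_calls']` assumes the key is always present in the response (a `.get('tool_calls')` would be the defensive variant).

```python
# Hypothetical response 'message' (shape inferred from the parsing code):
message = {
    'content': '',
    'tool_calls': [{
        'type': 'function',
        'function': {'name': 'get_weather', 'arguments': '{"location": "Beijing"}'},
    }],
}

# Each entry becomes:
AssistantPromptMessage.ToolCall(
    id='get_weather',  # no call id in the payload; the name is reused
    type='function',
    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
        name='get_weather',
        arguments='{"location": "Beijing"}',
    ),
)
```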
@@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         if credentials.get('context_size'):
             model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
                 credentials.get('context_size', 4096))
+
+        model_features = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('features', [])
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties=model_properties,
-            parameter_rules=rules
+            parameter_rules=rules,
+            features=model_features,
         )
 
         return entity
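
The feature flags resolved here come from the static `ModelConfigs` table in the next file. A one-line sketch of the lookup, with an example key:

```python
# 'Doubao-pro-4k' is an example base_model_name; unknown names fall back to [].
model_features = ModelConfigs.get('Doubao-pro-4k', {}).get('features', [])
# -> [ModelFeature.TOOL_CALL]
```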


@@ -1,3 +1,5 @@
+from core.model_runtime.entities.model_entities import ModelFeature
+
 ModelConfigs = {
     'Doubao-pro-4k': {
         'req_params': {
@@ -7,7 +9,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-4k': {
         'req_params': {
@@ -17,7 +22,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-pro-32k': {
         'req_params': {
@@ -27,7 +35,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 32768,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-32k': {
         'req_params': {
@@ -37,7 +48,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 32768,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-pro-128k': {
         'req_params': {
@@ -47,7 +61,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 131072,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-128k': {
         'req_params': {
@@ -57,7 +74,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 131072,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Skylark2-pro-4k': {
         'req_params': {
@@ -67,26 +87,29 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     },
     'Llama3-8B': {
         'req_params': {
             'max_prompt_tokens': 8192,
             'max_new_tokens': 8192,
         },
         'model_properties': {
             'context_size': 8192,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     },
     'Llama3-70B': {
         'req_params': {
             'max_prompt_tokens': 8192,
             'max_new_tokens': 8192,
         },
         'model_properties': {
             'context_size': 8192,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     }
 }
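
To close the loop, a small illustrative helper (not part of this commit) showing how a consumer could query the new feature flags:

```python
from core.model_runtime.entities.model_entities import ModelFeature

def supports_tool_call(base_model_name: str) -> bool:
    """Hypothetical helper: does this base model advertise function calling?"""
    features = ModelConfigs.get(base_model_name, {}).get('features', [])
    return ModelFeature.TOOL_CALL in features

assert supports_tool_call('Doubao-pro-32k')   # all Doubao variants: True
assert not supports_tool_call('Llama3-8B')    # Llama3 and Skylark2: False
```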