feat: support doubao llm function calling (#5100)

This commit is contained in:
sino 2024-06-12 15:43:50 +08:00 committed by GitHub
parent 25b0a97851
commit 0ce97e6315
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 91 additions and 16 deletions

View File

@ -7,7 +7,9 @@ from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
@ -36,10 +38,11 @@ class MaaSClient(MaasService):
client.set_sk(sk)
return client
def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
req = {
'parameters': params,
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
**extra_model_kwargs,
}
if not stream:
return super().chat(
@ -89,10 +92,22 @@ class MaaSClient(MaasService):
message = cast(AssistantPromptMessage, message)
message_dict = {'role': ChatRole.ASSISTANT,
'content': message.content}
if message.tool_calls:
message_dict['tool_calls'] = [
{
'name': call.function.name,
'arguments': call.function.arguments
} for call in message.tool_calls
]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {'role': ChatRole.SYSTEM,
'content': message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {'role': ChatRole.FUNCTION,
'content': message.content,
'name': message.tool_call_id}
else:
raise ValueError(f"Got unknown PromptMessage type {message}")
@ -106,3 +121,14 @@ class MaaSClient(MaasService):
raise wrap_error(e)
return resp
@staticmethod
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
}

View File

@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
if stop:
req_params['stop'] = stop
extra_model_kwargs = {}
if tools:
extra_model_kwargs['tools'] = [
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
]
resp = MaaSClient.wrap_exception(
lambda: client.chat(req_params, prompt_messages, stream))
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
if not stream:
return self._handle_chat_response(model, credentials, prompt_messages, resp)
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
choice = choices[0]
message = choice['message']
# parse tool calls
tool_calls = []
if message['tool_calls']:
for call in message['tool_calls']:
tool_call = AssistantPromptMessage.ToolCall(
id=call['function']['name'],
type=call['type'],
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=call['function']['name'],
arguments=call['function']['arguments']
)
)
tool_calls.append(tool_call)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=[],
tool_calls=tool_calls,
),
usage=self._calc_usage(model, credentials, resp['usage']),
)
@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
model_features = ModelConfigs.get(
credentials['base_model_name'], {}).get('features', [])
entity = AIModelEntity(
model=model,
label=I18nObject(
@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties=model_properties,
parameter_rules=rules
parameter_rules=rules,
features=model_features,
)
return entity

View File

@ -1,3 +1,5 @@
from core.model_runtime.entities.model_entities import ModelFeature
ModelConfigs = {
'Doubao-pro-4k': {
'req_params': {
@ -7,7 +9,10 @@ ModelConfigs = {
'model_properties': {
'context_size': 4096,
'mode': 'chat',
}
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-4k': {
'req_params': {
@ -17,7 +22,10 @@ ModelConfigs = {
'model_properties': {
'context_size': 4096,
'mode': 'chat',
}
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-32k': {
'req_params': {
@ -27,7 +35,10 @@ ModelConfigs = {
'model_properties': {
'context_size': 32768,
'mode': 'chat',
}
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-32k': {
'req_params': {
@ -37,7 +48,10 @@ ModelConfigs = {
'model_properties': {
'context_size': 32768,
'mode': 'chat',
}
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-128k': {
'req_params': {
@ -47,7 +61,10 @@ ModelConfigs = {
'model_properties': {
'context_size': 131072,
'mode': 'chat',
}
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-128k': {
'req_params': {
@ -57,7 +74,10 @@ ModelConfigs = {
'model_properties': {
'context_size': 131072,
'mode': 'chat',
}
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Skylark2-pro-4k': {
'req_params': {
@ -67,26 +87,29 @@ ModelConfigs = {
'model_properties': {
'context_size': 4096,
'mode': 'chat',
}
},
'features': [],
},
'Llama3-8B': {
'req_params': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
}
},
'features': [],
},
'Llama3-70B': {
'req_params': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
}
},
'features': [],
}
}