From 0ce97e6315836a95cc07ec6d9fe099e8537311c5 Mon Sep 17 00:00:00 2001
From: sino
Date: Wed, 12 Jun 2024 15:43:50 +0800
Subject: [PATCH] feat: support doubao llm function calling (#5100)

---
 .../model_providers/volcengine_maas/client.py | 30 ++++++++++++-
 .../volcengine_maas/llm/llm.py                | 32 +++++++++++--
 .../volcengine_maas/llm/models.py             | 45 ++++++++++++-----
 3 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py
index c7bf4fde8c..471cb3c94e 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/client.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py
@@ -7,7 +7,9 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
@@ -36,10 +38,11 @@ class MaaSClient(MaasService):
         client.set_sk(sk)
         return client
 
-    def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
+    def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
         req = {
             'parameters': params,
-            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
+            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
+            **extra_model_kwargs,
         }
         if not stream:
             return super().chat(
@@ -89,10 +92,22 @@ class MaaSClient(MaasService):
             message = cast(AssistantPromptMessage, message)
             message_dict = {'role': ChatRole.ASSISTANT,
                             'content': message.content}
+            if message.tool_calls:
+                message_dict['tool_calls'] = [
+                    {
+                        'name': call.function.name,
+                        'arguments': call.function.arguments
+                    } for call in message.tool_calls
+                ]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {'role': ChatRole.SYSTEM,
                             'content': message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {'role': ChatRole.FUNCTION,
+                            'content': message.content,
+                            'name': message.tool_call_id}
         else:
             raise ValueError(f"Got unknown PromptMessage type {message}")
 
@@ -106,3 +121,14 @@ class MaaSClient(MaasService):
                 raise wrap_error(e)
 
         return resp
+
+    @staticmethod
+    def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            }
+        }
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
index 7a36d019e2..8bea30324b 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
@@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         if stop:
             req_params['stop'] = stop
 
+        extra_model_kwargs = {}
+
+        if tools:
+            extra_model_kwargs['tools'] = [
+                MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
+            ]
+
         resp = MaaSClient.wrap_exception(
-            lambda: client.chat(req_params, prompt_messages, stream))
+            lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
         if not stream:
             return self._handle_chat_response(model, credentials, prompt_messages, resp)
         return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
@@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         choice = choices[0]
         message = choice['message']
 
+        # parse tool calls
+        tool_calls = []
+        if message['tool_calls']:
+            for call in message['tool_calls']:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=call['function']['name'],
+                    type=call['type'],
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=call['function']['name'],
+                        arguments=call['function']['arguments']
+                    )
+                )
+                tool_calls.append(tool_call)
+
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=AssistantPromptMessage(
                 content=message['content'] if message['content'] else '',
-                tool_calls=[],
+                tool_calls=tool_calls,
             ),
             usage=self._calc_usage(model, credentials, resp['usage']),
         )
@@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         if credentials.get('context_size'):
             model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
                 credentials.get('context_size', 4096))
+
+        model_features = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('features', [])
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 en_US=model
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties=model_properties,
-            parameter_rules=rules
+            parameter_rules=rules,
+            features=model_features,
         )
 
         return entity
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
index a2bd81b945..3a793cd6a8 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
@@ -1,3 +1,5 @@
+from core.model_runtime.entities.model_entities import ModelFeature
+
 ModelConfigs = {
     'Doubao-pro-4k': {
         'req_params': {
@@ -7,7 +9,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-4k': {
         'req_params': {
@@ -17,7 +22,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-pro-32k': {
         'req_params': {
@@ -27,7 +35,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 32768,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-32k': {
         'req_params': {
@@ -37,7 +48,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 32768,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-pro-128k': {
         'req_params': {
@@ -47,7 +61,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 131072,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-128k': {
         'req_params': {
@@ -57,7 +74,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 131072,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Skylark2-pro-4k': {
         'req_params': {
@@ -67,26 +87,29 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     },
     'Llama3-8B': {
-         'req_params': {
+        'req_params': {
             'max_prompt_tokens': 8192,
             'max_new_tokens': 8192,
         },
         'model_properties': {
             'context_size': 8192,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     },
     'Llama3-70B': {
-         'req_params': {
+        'req_params': {
             'max_prompt_tokens': 8192,
             'max_new_tokens': 8192,
         },
         'model_properties': {
             'context_size': 8192,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     }
 }
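
A quick reviewer's sketch of how the new tool-calling path is exercised on the request side. It is not part of the patch: the get_weather tool, its parameter schema, and the variable names are hypothetical, and the snippet only assumes the Dify model-runtime entities and the MaaSClient helper shown in the diff above.

    from core.model_runtime.entities.message_entities import PromptMessageTool
    from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient

    # Hypothetical tool definition in Dify's PromptMessageTool shape.
    weather_tool = PromptMessageTool(
        name='get_weather',
        description='Get the current weather for a city',
        parameters={
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    )

    # The new static helper wraps the tool into the MaaS 'function' config.
    # llm.py builds a list of these when `tools` is set and forwards it via
    # client.chat(..., **extra_model_kwargs), so it lands in the request body
    # alongside 'parameters' and 'messages'.
    extra_model_kwargs = {
        'tools': [MaaSClient.transform_tool_prompt_to_maas_config(weather_tool)],
    }

On the response side, _handle_chat_response maps the returned tool_calls into AssistantPromptMessage.ToolCall objects, and the tool's output can be sent back as a ToolPromptMessage, which convert_prompt_message_to_maas_message serializes under ChatRole.FUNCTION.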