From d7a6f25c636c29a29176ba84bfad62b1e3000886 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=90=E5=B0=8F=E5=BF=83?= Date: Fri, 12 Jul 2024 11:07:38 +0800 Subject: [PATCH] fix: differentiate prompts fields based on function_calling_type (#5880) --- .../openai_api_compatible/llm/llm.py | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index b76f460737..e5cc884b6d 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -616,30 +616,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call - # in - # message.tool_calls] - - function_call = message.tool_calls[0] - message_dict["function_call"] = { - "name": function_call.function.name, - "arguments": function_call.function.arguments, - } + function_calling_type = credentials.get('function_calling_type', 'no_call') + if function_calling_type == 'tool_call': + message_dict["tool_calls"] = [tool_call.dict() for tool_call in + message.tool_calls] + elif function_calling_type == 'function_call': + function_call = message.tool_calls[0] + message_dict["function_call"] = { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - # message_dict = { - # "role": "tool", - # "content": message.content, - # "tool_call_id": message.tool_call_id - # } - message_dict = { - "role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function", - "content": message.content, - "name": message.tool_call_id - } + function_calling_type = credentials.get('function_calling_type', 'no_call') + if function_calling_type == 'tool_call': + message_dict = { + "role": "tool", + "content": message.content, + "tool_call_id": message.tool_call_id + } + elif function_calling_type == 'function_call': + message_dict = { + "role": "function", + "content": message.content, + "name": message.tool_call_id + } else: raise ValueError(f"Got unknown type {message}")