From fbfc811a447b46bea05c58177551903c737dc1a6 Mon Sep 17 00:00:00 2001
From: GeorgeCaoJ
Date: Wed, 20 Nov 2024 11:15:19 +0800
Subject: [PATCH] feat: support function call for ollama block chat api (#10784)

---
 .../model_providers/ollama/llm/llm.py         | 68 +++++++++++++++++--
 .../model_providers/ollama/ollama.yaml        | 19 ++++++
 2 files changed, 82 insertions(+), 5 deletions(-)

diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py
index a7ea53e0e9..094a674645 100644
--- a/api/core/model_runtime/model_providers/ollama/llm/llm.py
+++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py
@@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
@@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             credentials=credentials,
             prompt_messages=prompt_messages,
             model_parameters=model_parameters,
+            tools=tools,
             stop=stop,
             stream=stream,
             user=user,
@@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
@@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, "api/chat")
             data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            if tools:
+                data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
         else:
             endpoint_url = urljoin(endpoint_url, "api/generate")
             first_prompt_message = prompt_messages[0]
@@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if stream:
             return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
 
-        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
 
     def _handle_generate_response(
         self,
@@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         completion_type: LLMMode,
         response: requests.Response,
         prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]],
     ) -> LLMResult:
         """
         Handle llm completion response
@@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         :return: llm result
         """
         response_json = response.json()
-
+        tool_calls = []
         if completion_type is LLMMode.CHAT:
             message = response_json.get("message", {})
             response_content = message.get("content", "")
+            response_tool_calls = message.get("tool_calls", [])
+            tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
         else:
             response_content = response_json["response"]
 
-        assistant_message = AssistantPromptMessage(content=response_content)
+        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
 
         if "prompt_eval_count" in response_json and "eval_count" in response_json:
             # transform usage
@@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
             chunk_index += 1
 
+    def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
+        """
+        Convert PromptMessageTool to dict for Ollama API
+
+        :param tool: tool
+        :return: tool dict
+        """
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            },
+        }
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for Ollama API
+
+        :param message: prompt message
+        :return: message dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
@@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content}
         else:
             raise ValueError(f"Got unknown type {message}")
 
@@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
+    def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract response tool call
+        """
+        tool_call = None
+        if response_tool_call and "function" in response_tool_call:
+            # Convert arguments to JSON string if it's a dict
+            arguments = response_tool_call.get("function").get("arguments")
+            if isinstance(arguments, dict):
+                arguments = json.dumps(arguments)
+
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_tool_call.get("function").get("name"),
+                arguments=arguments,
+            )
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_tool_call.get("function").get("name"),
+                type="function",
+                function=function,
+            )
+
+        return tool_call
+
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         """
         Get customizable model schema.
@@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         :return: model schema
         """
-        extras = {}
+        extras = {
+            "features": [],
+        }
 
         if "vision_support" in credentials and credentials["vision_support"] == "true":
-            extras["features"] = [ModelFeature.VISION]
+            extras["features"].append(ModelFeature.VISION)
+        if "function_call_support" in credentials and credentials["function_call_support"] == "true":
+            extras["features"].append(ModelFeature.TOOL_CALL)
+            extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
 
         entity = AIModelEntity(
             model=model,
diff --git a/api/core/model_runtime/model_providers/ollama/ollama.yaml b/api/core/model_runtime/model_providers/ollama/ollama.yaml
index 33747753bd..6560fcd180 100644
--- a/api/core/model_runtime/model_providers/ollama/ollama.yaml
+++ b/api/core/model_runtime/model_providers/ollama/ollama.yaml
@@ -96,3 +96,22 @@ model_credential_schema:
         label:
           en_US: 'No'
           zh_Hans: 否
+  - variable: function_call_support
+    label:
+      zh_Hans: 是否支持函数调用
+      en_US: Function call support
+    show_on:
+      - variable: __model_type
+        value: llm
+    default: 'false'
+    type: radio
+    required: false
+    options:
+      - value: 'true'
+        label:
+          en_US: 'Yes'
+          zh_Hans: 是
+      - value: 'false'
+        label:
+          en_US: 'No'
+          zh_Hans: 否
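
Usage sketch (illustrative, not part of the patch): the snippet below shows the request and response shapes the change above targets when it calls Ollama's blocking /api/chat endpoint. It assumes a local Ollama server at http://localhost:11434 and a tool-capable model such as "llama3.1"; both are placeholders, as is the get_current_weather tool. The "tools" payload matches what _convert_prompt_message_tool_to_dict emits, and the message.tool_calls handling mirrors what _extract_response_tool_call does with the parsed response.

# Illustrative only -- assumes a local Ollama server and a tool-capable model.
import json

import requests

# Same shape that _convert_prompt_message_tool_to_dict emits for each PromptMessageTool.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical example tool
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

payload = {
    "model": "llama3.1",  # assumption: any Ollama model that supports tool calls
    "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
    "tools": tools,
    "stream": False,  # the patch wires tools into the blocking (non-stream) chat path
}

response = requests.post("http://localhost:11434/api/chat", json=payload, timeout=60)
response.raise_for_status()
message = response.json().get("message", {})

# message.tool_calls is what _extract_response_tool_call walks: each entry carries a
# "function" dict whose "arguments" may be a dict and is dumped to a JSON string.
for tool_call in message.get("tool_calls", []):
    function = tool_call.get("function", {})
    arguments = function.get("arguments")
    if isinstance(arguments, dict):
        arguments = json.dumps(arguments)
    print(function.get("name"), arguments)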