From a258a9029171a1b3ad403f88cf8482f49c18a7d4 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Fri, 12 Apr 2024 16:38:02 +0800
Subject: [PATCH] feat: gemini pro function call (#3406)

---
 .../google/llm/gemini-1.5-pro-latest.yaml |   2 +
 .../google/llm/gemini-pro.yaml            |   2 +
 .../model_providers/google/llm/llm.py     | 198 ++++++++++++------
 .../model_runtime/__mock/google.py        |  11 +-
 4 files changed, 151 insertions(+), 62 deletions(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
index 892284ae0b..d65dc02674 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 1048576
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
index ffdc9c3659..4e9f59e7da 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 30720
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 2feff8ebe9..27912b13cc 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -1,7 +1,9 @@
+import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union
 
+import google.ai.generativelanguage as glm
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai.client as client
@@ -13,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
-    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -62,7 +64,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -94,6 +96,32 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         )
 
         return text.rstrip()
+
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -105,7 +133,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
 
         try:
-            ping_message = PromptMessage(content="ping", role="system")
+            ping_message = SystemPromptMessage(content="ping")
             self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
         except Exception as ex:
@@ -114,8 +142,9 @@
 
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None
+                  ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
@@ -153,7 +182,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             else:
                 history.append(content)
 
-        # Create a new ClientManager with tenant's API key
         new_client_manager = client._ClientManager()
         new_client_manager.configure(api_key=credentials["google_api_key"])
 
@@ -167,14 +195,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
         }
-
+
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(
                 **config_kwargs
             ),
             stream=stream,
-            safety_settings=safety_settings
+            safety_settings=safety_settings,
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
         )
 
         if stream:
@@ -228,43 +257,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         index = -1
         for chunk in response:
-            content = chunk.text
-            index += 1
-
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-
-            if not response._done:
-
-                # transform assistant message to prompt message
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
-                    )
+            for part in chunk.parts:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
                 )
-            else:
-
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.candidates[0].finish_reason,
-                        usage=usage
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+
+                if not response._done:
+
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
                     )
-                )
 
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
@@ -288,6 +335,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             message_text = f"{ai_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
             message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
 
@@ -300,26 +349,53 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :param message: one PromptMessage
         :return: glm Content representation of message
         """
-
-        parts = []
-        if (isinstance(message.content, str)):
-            parts.append(to_part(message.content))
+        if isinstance(message, UserPromptMessage):
+            glm_content = {
+                "role": "user",
+                "parts": []
+            }
+            if (isinstance(message.content, str)):
+                glm_content['parts'].append(to_part(message.content))
+            else:
+                for c in message.content:
+                    if c.type == PromptMessageContentType.TEXT:
+                        glm_content['parts'].append(to_part(c.data))
+                    else:
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                        glm_content['parts'].append(blob)
+            return glm_content
+        elif isinstance(message, AssistantPromptMessage):
+            glm_content = {
+                "role": "model",
+                "parts": []
+            }
+            if message.content:
+                glm_content['parts'].append(to_part(message.content))
+            if message.tool_calls:
+                glm_content["parts"].append(to_part(glm.FunctionCall(
+                    name=message.tool_calls[0].function.name,
+                    args=json.loads(message.tool_calls[0].function.arguments),
+                )))
+            return glm_content
+        elif isinstance(message, SystemPromptMessage):
+            return {
+                "role": "user",
+                "parts": [to_part(message.content)]
+            }
+        elif isinstance(message, ToolPromptMessage):
+            return {
+                "role": "function",
+                "parts": [glm.Part(function_response=glm.FunctionResponse(
+                    name=message.name,
+                    response={
+                        "response": message.content
+                    }
+                ))]
+            }
         else:
-            for c in message.content:
-                if c.type == PromptMessageContentType.TEXT:
-                    parts.append(to_part(c.data))
-                else:
-                    metadata, data = c.data.split(',', 1)
-                    mime_type = metadata.split(';', 1)[0].split(':')[1]
-                    blob = {"inline_data":{"mime_type":mime_type,"data":data}}
-                    parts.append(blob)
-
-        glm_content = {
-            "role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
-            "parts": parts
-        }
-
-        return glm_content
+            raise ValueError(f"Got unknown type {message}")
 
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
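
The conversion added in llm.py above can be pictured with a short, assumed example. The sketch below is illustrative rather than part of the change set: the tool name, description, and schema values are invented, and it simply feeds a Dify PromptMessageTool through the new _convert_tools_to_glm_tool helper.

    # Illustrative only: exercising the new helper with a hypothetical tool definition.
    from core.model_runtime.entities.message_entities import PromptMessageTool
    from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel

    weather_tool = PromptMessageTool(
        name='get_weather',
        description='Get the current weather for a city',
        parameters={
            'type': 'object',
            'properties': {
                'city': {'type': 'string', 'description': 'City name'},
                'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']},
            },
            'required': ['city'],
        },
    )

    llm = GoogleLargeLanguageModel()
    glm_tool = llm._convert_tools_to_glm_tool([weather_tool])
    # glm_tool is a glm.Tool with one FunctionDeclaration per PromptMessageTool; its
    # Schema mirrors the JSON-schema-style 'properties'/'required' fields above, and
    # _generate now passes it to google_model.generate_content(tools=...) when tools
    # are supplied.
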
diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py
index 4ac4dfe1f0..cc4d8c6fbd 100644
--- a/api/tests/integration_tests/model_runtime/__mock/google.py
+++ b/api/tests/integration_tests/model_runtime/__mock/google.py
@@ -10,6 +10,7 @@ from google.generativeai import GenerativeModel
 from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
+from google.ai.generativelanguage_v1beta.types import content as gag_content
 
 current_api_key = ''
 
@@ -29,7 +30,7 @@ class MockGoogleResponseClass(object):
                 }),
                 chunks=[]
 
-            )                
+            )
         else:
             yield GenerateContentResponse(
                 done=False,
@@ -43,6 +44,14 @@ class MockGoogleResponseCandidateClass(object):
 
     finish_reason = 'stop'
 
+    @property
+    def content(self) -> gag_content.Content:
+        return gag_content.Content(
+            parts=[
+                gag_content.Part(text='it\'s google!')
+            ]
+        )
+
 class MockGoogleClass(object):
     @staticmethod
     def generate_content_sync() -> GenerateContentResponse:
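
For completeness, a minimal, assumed driver for the new tool-call path, reusing the llm instance and weather_tool from the sketch above. The credentials and prompt are placeholders, and this is a sketch of typical usage rather than a test shipped with the patch.

    # Illustrative only: invoke gemini-pro with a tool and read tool calls from the
    # streamed chunks. The API key and prompt text are placeholders.
    from core.model_runtime.entities.message_entities import UserPromptMessage

    result = llm.invoke(
        model='gemini-pro',
        credentials={'google_api_key': 'YOUR_API_KEY'},
        prompt_messages=[UserPromptMessage(content='What is the weather in Tokyo?')],
        model_parameters={'temperature': 0.1},
        tools=[weather_tool],
        stream=True,
    )

    # With stream=True the result is a generator of LLMResultChunk; a Gemini function
    # call surfaces as AssistantPromptMessage.tool_calls on the chunk delta.
    for chunk in result:
        message = chunk.delta.message
        if message.tool_calls:
            call = message.tool_calls[0]
            print('tool call:', call.function.name, call.function.arguments)
        elif message.content:
            print(message.content, end='')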