diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 1c27b2b4aa..a868cb8c78 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -2,6 +2,7 @@ - anthropic - azure_openai - google +- vertex_ai - nvidia - cohere - bedrock diff --git a/api/core/model_runtime/model_providers/vertex_ai/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png new file mode 100644 index 0000000000..9f8f05231a Binary files /dev/null and b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg new file mode 100644 index 0000000000..efc3589c07 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py new file mode 100644 index 0000000000..8f7c859e38 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py @@ -0,0 +1,15 @@ +from core.model_runtime.errors.invoke import InvokeError + + +class _CommonVertexAi: + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + pass diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml new file mode 100644 index 0000000000..da3bc8a64a --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml @@ -0,0 +1,38 @@ +model: gemini-1.0-pro-vision-001 +label: + en_US: Gemini 1.0 Pro Vision +model_type: llm +features: + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml new file mode 100644 index 0000000000..029fab718c --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml @@ -0,0 +1,38 @@ +model: gemini-1.0-pro-002 +label: + en_US: Gemini 1.0 Pro +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32760 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml new file mode 100644 index 0000000000..72b8410aa1 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml @@ -0,0 +1,38 @@ +model: gemini-1.5-flash-preview-0514 +label: + en_US: Gemini 1.5 Flash +model_type: llm +features: + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml new file mode 100644 index 0000000000..141f61aad6 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-pro-preview-0514 +label: + en_US: Gemini 1.5 Pro +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. 
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
new file mode 100644
index 0000000000..5e3905af98
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -0,0 +1,438 @@
+import base64
+import json
+import logging
+from collections.abc import Generator
+from typing import Optional, Union
+
+import google.api_core.exceptions as exceptions
+import vertexai.generative_models as glm
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+ + +{{instructions}} + +""" + + +class VertexAiLargeLanguageModel(LargeLanguageModel): + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # invoke model + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return:md = gml.GenerativeModel(model) + """ + prompt = self._convert_messages_to_prompt(prompt_messages) + + return self._get_num_tokens_by_gpt2(prompt) + + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: + """ + Format a list of messages into a full prompt for the Google model + + :param messages: List of PromptMessage to combine. + :return: Combined string with necessary human_prompt and ai_prompt tags. + """ + messages = messages.copy() # don't mutate the original list + + text = "".join( + self._convert_one_message_to_text(message) + for message in messages + ) + + return text.rstrip() + + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: + """ + Convert tool messages to glm tools + + :param tools: tool messages + :return: glm tools + """ + return glm.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name=tool.name, + parameters=glm.Schema( + type=glm.Type.OBJECT, + properties={ + key: { + 'type_': value.get('type', 'string').upper(), + 'description': value.get('description', ''), + 'enum': value.get('enum', []) + } for key, value in tool.parameters.get('properties', {}).items() + }, + required=tool.parameters.get('required', []) + ), + ) for tool in tools + ] + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + + try: + ping_message = SystemPromptMessage(content="ping") + self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) + + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + + def _generate(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: credentials kwargs + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream 
response chunk generator result + """ + config_kwargs = model_parameters.copy() + config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + + if stop: + config_kwargs["stop_sequences"] = stop + + service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + project_id = credentials["vertex_project_id"] + location = credentials["vertex_location"] + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + + history = [] + system_instruction = GEMINI_BLOCK_MODE_PROMPT + # hack for gemini-pro-vision, which currently does not support multi-turn chat + if model == "gemini-1.0-pro-vision-001": + last_msg = prompt_messages[-1] + content = self._format_message_to_glm_content(last_msg) + history.append(content) + else: + for msg in prompt_messages: + if isinstance(msg, SystemPromptMessage): + system_instruction = msg.content + else: + content = self._format_message_to_glm_content(msg) + if history and history[-1].role == content.role: + history[-1].parts.extend(content.parts) + else: + history.append(content) + + safety_settings={ + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + + google_model = glm.GenerativeModel( + model_name=model, + system_instruction=system_instruction + ) + + response = google_model.generate_content( + contents=history, + generation_config=glm.GenerationConfig( + **config_kwargs + ), + stream=stream, + safety_settings=safety_settings, + tools=self._convert_tools_to_glm_tool(tools) if tools else None + ) + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse, + prompt_messages: list[PromptMessage]) -> LLMResult: + """ + Handle llm response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response + """ + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=response.candidates[0].content.parts[0].text + ) + + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + ) + + return result + + def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse, + prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response chunk generator result + """ + index = -1 + for chunk in response: + for part in chunk.candidates[0].content.parts: + 
assistant_prompt_message = AssistantPromptMessage( + content='' + ) + + if part.text: + assistant_prompt_message.content += part.text + + if part.function_call: + assistant_prompt_message.tool_calls = [ + AssistantPromptMessage.ToolCall( + id=part.function_call.name, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=part.function_call.name, + arguments=json.dumps({ + key: value + for key, value in part.function_call.args.items() + }) + ) + ) + ] + + index += 1 + + if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason: + # transform assistant message to prompt message + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message + ) + ) + else: + + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + finish_reason=chunk.candidates[0].finish_reason, + usage=usage + ) + ) + + def _convert_one_message_to_text(self, message: PromptMessage) -> str: + """ + Convert a single message to a string. + + :param message: PromptMessage to convert. + :return: String representation of the message. + """ + human_prompt = "\n\nuser:" + ai_prompt = "\n\nmodel:" + + content = message.content + if isinstance(content, list): + content = "".join( + c.data for c in content if c.type != PromptMessageContentType.IMAGE + ) + + if isinstance(message, UserPromptMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, AssistantPromptMessage): + message_text = f"{ai_prompt} {content}" + elif isinstance(message, SystemPromptMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, ToolPromptMessage): + message_text = f"{human_prompt} {content}" + else: + raise ValueError(f"Got unknown type {message}") + + return message_text + + def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: + """ + Format a single message into glm.Content for Google API + + :param message: one PromptMessage + :return: glm Content representation of message + """ + if isinstance(message, UserPromptMessage): + glm_content = glm.Content(role="user", parts=[]) + + if (isinstance(message.content, str)): + glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)]) + else: + parts = [] + for c in message.content: + if c.type == PromptMessageContentType.TEXT: + parts.append(glm.Part.from_text(c.data)) + else: + metadata, data = c.data.split(',', 1) + mime_type = metadata.split(';', 1)[0].split(':')[1] + blob = {"inline_data":{"mime_type":mime_type,"data":data}} + parts.append(blob) + + glm_content = glm.Content(role="user", parts=[parts]) + return glm_content + elif isinstance(message, AssistantPromptMessage): + if message.content: + glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)]) + if message.tool_calls: + glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ))]) + return glm_content + elif isinstance(message, ToolPromptMessage): + 
glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse( + name=message.name, + response={ + "response": message.content + } + ))]) + return glm_content + else: + raise ValueError(f"Got unknown type {message}") + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the ermd = gml.GenerativeModel(model)ror type thrown to the caller + The value is the md = gml.GenerativeModel(model)error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke emd = gml.GenerativeModel(model)rror mapping + """ + return { + InvokeConnectionError: [ + exceptions.RetryError + ], + InvokeServerUnavailableError: [ + exceptions.ServiceUnavailable, + exceptions.InternalServerError, + exceptions.BadGateway, + exceptions.GatewayTimeout, + exceptions.DeadlineExceeded + ], + InvokeRateLimitError: [ + exceptions.ResourceExhausted, + exceptions.TooManyRequests + ], + InvokeAuthorizationError: [ + exceptions.Unauthenticated, + exceptions.PermissionDenied, + exceptions.Unauthenticated, + exceptions.Forbidden + ], + InvokeBadRequestError: [ + exceptions.BadRequest, + exceptions.InvalidArgument, + exceptions.FailedPrecondition, + exceptions.OutOfRange, + exceptions.NotFound, + exceptions.MethodNotAllowed, + exceptions.Conflict, + exceptions.AlreadyExists, + exceptions.Aborted, + exceptions.LengthRequired, + exceptions.PreconditionFailed, + exceptions.RequestRangeNotSatisfiable, + exceptions.Cancelled, + ] + } \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml new file mode 100644 index 0000000000..32db6faf89 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml @@ -0,0 +1,8 @@ +model: text-embedding-004 +model_type: text-embedding +model_properties: + context_size: 2048 +pricing: + input: '0.00013' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml new file mode 100644 index 0000000000..2ec0eea9f2 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml @@ -0,0 +1,8 @@ +model: text-multilingual-embedding-002 +model_type: text-embedding +model_properties: + context_size: 2048 +pricing: + input: '0.00013' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py new file mode 100644 index 0000000000..ece63806c3 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -0,0 +1,193 @@ +import base64 +import json +import time +from decimal import Decimal +from typing import Optional + +import tiktoken +from google.cloud import aiplatform +from google.oauth2 import service_account +from vertexai.language_models import TextEmbeddingModel as 
VertexTextEmbeddingModel + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi + + +class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): + """ + Model class for Vertex AI text embedding model. + """ + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: embeddings result + """ + service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + project_id = credentials["vertex_project_id"] + location = credentials["vertex_location"] + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + + client = VertexTextEmbeddingModel.from_pretrained(model) + + + + embeddings_batch, embedding_used_tokens = self._embedding_invoke( + client=client, + texts=texts + ) + + # calc usage + usage = self._calc_response_usage( + model=model, + credentials=credentials, + tokens=embedding_used_tokens + ) + + return TextEmbeddingResult( + embeddings=embeddings_batch, + usage=usage, + model=model + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + if len(texts) == 0: + return 0 + + try: + enc = tiktoken.encoding_for_model(model) + except KeyError: + enc = tiktoken.get_encoding("cl100k_base") + + total_num_tokens = 0 + for text in texts: + # calculate the number of tokens in the encoded text + tokenized_text = enc.encode(text) + total_num_tokens += len(tokenized_text) + + return total_num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + project_id = credentials["vertex_project_id"] + location = credentials["vertex_location"] + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + + client = VertexTextEmbeddingModel.from_pretrained(model) + + # call embedding model + self._embedding_invoke( + model=model, + client=client, + texts=['ping'] + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore + """ + Invoke embedding model + + :param model: model name + :param client: model client + :param texts: texts to embed + :return: embeddings and used tokens + """ + response = 
client.get_embeddings(texts) + + embeddings = [] + token_usage = 0 + + for i in range(len(response)): + embeddings.append(response[i].values) + token_usage += int(response[i].statistics.token_count) + + return embeddings, token_usage + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.MAX_CHUNKS: 1, + }, + parameter_rules=[], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ) + ) + + return entity diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py new file mode 100644 index 0000000000..3cbfb088d1 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py @@ -0,0 +1,31 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VertexAiProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + # Use `gemini-1.0-pro-002` model for validate, + model_instance.validate_credentials( + model='gemini-1.0-pro-002', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml new file mode 100644 index 0000000000..8b7f216b55 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml @@ -0,0 +1,43 @@ +provider: vertex_ai +label: + en_US: Vertex AI | Google Cloud Platform +description: + en_US: Vertex AI in Google Cloud Platform. 
+icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.png +background: "#FCFDFF" +help: + title: + en_US: Get your Access Details from Google + url: + en_US: https://cloud.google.com/vertex-ai/ +supported_model_types: + - llm + - text-embedding +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: vertex_project_id + label: + en_US: Project ID + type: text-input + required: true + placeholder: + en_US: Enter your Google Cloud Project ID + - variable: vertex_location + label: + en_US: Location + type: text-input + required: true + placeholder: + en_US: Enter your Google Cloud Location + - variable: vertex_service_account_key + label: + en_US: Service Account Key + type: secret-input + required: true + placeholder: + en_US: Enter your Google Cloud Service Account Key in base64 format diff --git a/api/requirements.txt b/api/requirements.txt index c9e9c2fa29..306f600afc 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -84,3 +84,4 @@ pgvecto-rs==0.1.4 firecrawl-py==0.0.5 oss2==2.18.5 pgvector==0.2.5 +google-cloud-aiplatform==1.49.0
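
Note: the new provider expects the `vertex_service_account_key` credential as a base64-encoded service-account JSON key (both llm.py and text_embedding.py call base64.b64decode on it before passing it to service_account.Credentials.from_service_account_info). A minimal sketch for producing that value from a downloaded key file; the file path below is a placeholder, not part of this change:

# encode_key.py - hedged helper sketch: turn a downloaded service-account JSON key
# into the base64 string expected by the vertex_service_account_key field.
import base64
import json
from pathlib import Path

def encode_service_account_key(key_path: str) -> str:
    raw = Path(key_path).read_bytes()
    json.loads(raw)  # fail early if the file is not valid JSON
    return base64.b64encode(raw).decode("utf-8")

if __name__ == "__main__":
    # "service-account.json" is a placeholder path.
    print(encode_service_account_key("service-account.json"))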
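
For reference, _generate in llm.py reduces to the following Vertex AI SDK call path: build explicit service-account credentials, call aiplatform.init, then GenerativeModel.generate_content. This standalone sketch uses only SDK calls that already appear in the diff; the environment variable name, project ID, location, and sampling values are placeholder assumptions:

# llm_sketch.py - sketch of the call path the new LLM provider wraps.
# VERTEX_SERVICE_ACCOUNT_KEY, project, and location are assumed placeholders.
import base64
import json
import os

import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account

info = json.loads(base64.b64decode(os.environ["VERTEX_SERVICE_ACCOUNT_KEY"]))
credentials = service_account.Credentials.from_service_account_info(info)
aiplatform.init(credentials=credentials, project="my-gcp-project", location="us-central1")

model = glm.GenerativeModel(model_name="gemini-1.0-pro-002")
response = model.generate_content(
    contents=["ping"],
    generation_config=glm.GenerationConfig(max_output_tokens=256, temperature=0.2),
    stream=False,
)
print(response.candidates[0].content.parts[0].text)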
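
Likewise, _embedding_invoke in text_embedding.py boils down to TextEmbeddingModel.from_pretrained(...).get_embeddings(...), reading .values and .statistics.token_count from each result. A short sketch, assuming aiplatform.init has already been called with valid credentials as in the previous snippet:

# embedding_sketch.py - sketch of the embedding call path; assumes aiplatform.init()
# was already called with valid credentials (see the LLM sketch above).
from vertexai.language_models import TextEmbeddingModel

model = TextEmbeddingModel.from_pretrained("text-embedding-004")
for embedding in model.get_embeddings(["ping"]):
    print(len(embedding.values), int(embedding.statistics.token_count))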