diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index 1c27b2b4aa..a868cb8c78 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -2,6 +2,7 @@
- anthropic
- azure_openai
- google
+- vertex_ai
- nvidia
- cohere
- bedrock
diff --git a/api/core/model_runtime/model_providers/vertex_ai/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png
new file mode 100644
index 0000000000..9f8f05231a
Binary files /dev/null and b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png differ
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..efc3589c07
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py
new file mode 100644
index 0000000000..8f7c859e38
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py
@@ -0,0 +1,15 @@
+from core.model_runtime.errors.invoke import InvokeError
+
+
+class _CommonVertexAi:
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ pass
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml
new file mode 100644
index 0000000000..da3bc8a64a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml
@@ -0,0 +1,38 @@
+model: gemini-1.0-pro-vision-001
+label:
+ en_US: Gemini 1.0 Pro Vision
+model_type: llm
+features:
+ - vision
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 16384
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 2048
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml
new file mode 100644
index 0000000000..029fab718c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml
@@ -0,0 +1,38 @@
+model: gemini-1.0-pro-002
+label:
+ en_US: Gemini 1.0 Pro
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 32760
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 8192
+ min: 1
+ max: 8192
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml
new file mode 100644
index 0000000000..72b8410aa1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml
@@ -0,0 +1,38 @@
+model: gemini-1.5-flash-preview-0514
+label:
+ en_US: Gemini 1.5 Flash
+model_type: llm
+features:
+ - vision
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 1048576
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 8192
+ min: 1
+ max: 8192
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
new file mode 100644
index 0000000000..141f61aad6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
@@ -0,0 +1,39 @@
+model: gemini-1.5-pro-preview-0514
+label:
+ en_US: Gemini 1.5 Pro
+model_type: llm
+features:
+ - agent-thought
+ - vision
+ - tool-call
+ - stream-tool-call
+model_properties:
+ mode: chat
+ context_size: 1048576
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ en_US: Top k
+ type: int
+ help:
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: presence_penalty
+ use_template: presence_penalty
+ - name: frequency_penalty
+ use_template: frequency_penalty
+ - name: max_output_tokens
+ use_template: max_tokens
+ required: true
+ default: 8192
+ min: 1
+ max: 8192
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
new file mode 100644
index 0000000000..5e3905af98
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -0,0 +1,438 @@
+import base64
+import json
+import logging
+from collections.abc import Generator
+from typing import Optional, Union
+
+import google.api_core.exceptions as exceptions
+import vertexai.generative_models as glm
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ PromptMessageContentType,
+ PromptMessageTool,
+ SystemPromptMessage,
+ ToolPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
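+# Default system instruction used when the caller does not supply a system message;
+# {{block}} and {{instructions}} are template placeholders and are left untouched in this module.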
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about it.
+
+
+{{instructions}}
+
+"""
+
+
+class VertexAiLargeLanguageModel(LargeLanguageModel):
+
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ # invoke model
+ return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+        :return: number of tokens
+ """
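+        # the token count is approximated locally with the GPT-2 tokenizer rather than
+        # calling the Vertex AI count_tokens endpoint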
+ prompt = self._convert_messages_to_prompt(prompt_messages)
+
+ return self._get_num_tokens_by_gpt2(prompt)
+
+ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+ """
+ Format a list of messages into a full prompt for the Google model
+
+ :param messages: List of PromptMessage to combine.
+ :return: Combined string with necessary human_prompt and ai_prompt tags.
+ """
+ messages = messages.copy() # don't mutate the original list
+
+ text = "".join(
+ self._convert_one_message_to_text(message)
+ for message in messages
+ )
+
+ return text.rstrip()
+
+ def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+ """
+ Convert tool messages to glm tools
+
+ :param tools: tool messages
+ :return: glm tools
+ """
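+        # map each OpenAI-style function definition onto a Vertex AI FunctionDeclaration;
+        # JSON Schema type names are upper-cased to line up with the glm.Type enum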
+ return glm.Tool(
+ function_declarations=[
+ glm.FunctionDeclaration(
+ name=tool.name,
+ parameters=glm.Schema(
+ type=glm.Type.OBJECT,
+ properties={
+ key: {
+ 'type_': value.get('type', 'string').upper(),
+ 'description': value.get('description', ''),
+ 'enum': value.get('enum', [])
+ } for key, value in tool.parameters.get('properties', {}).items()
+ },
+ required=tool.parameters.get('required', [])
+ ),
+ ) for tool in tools
+ ]
+ )
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+
+ try:
+ ping_message = SystemPromptMessage(content="ping")
+ self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+
+ def _generate(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None
+ ) -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: credentials kwargs
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ config_kwargs = model_parameters.copy()
+        # map `max_tokens_to_sample` onto Vertex AI's `max_output_tokens` without
+        # clobbering a value already supplied under the native name
+        if 'max_tokens_to_sample' in config_kwargs:
+            config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample')
+
+ if stop:
+ config_kwargs["stop_sequences"] = stop
+
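+        # credentials carry a base64-encoded service account JSON key; decode it and
+        # initialise the Vertex AI SDK against the configured project and location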
+ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+ service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+ project_id = credentials["vertex_project_id"]
+ location = credentials["vertex_location"]
+ aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+ history = []
+ system_instruction = GEMINI_BLOCK_MODE_PROMPT
+ # hack for gemini-pro-vision, which currently does not support multi-turn chat
+ if model == "gemini-1.0-pro-vision-001":
+ last_msg = prompt_messages[-1]
+ content = self._format_message_to_glm_content(last_msg)
+ history.append(content)
+ else:
+ for msg in prompt_messages:
+ if isinstance(msg, SystemPromptMessage):
+ system_instruction = msg.content
+ else:
+ content = self._format_message_to_glm_content(msg)
+ if history and history[-1].role == content.role:
+ history[-1].parts.extend(content.parts)
+ else:
+ history.append(content)
+
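+        # disable Vertex AI safety filtering across all harm categories so responses are
+        # not silently blocked; tighten these thresholds if stricter filtering is required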
+        safety_settings = {
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+ }
+
+ google_model = glm.GenerativeModel(
+ model_name=model,
+ system_instruction=system_instruction
+ )
+
+ response = google_model.generate_content(
+ contents=history,
+ generation_config=glm.GenerationConfig(
+ **config_kwargs
+ ),
+ stream=stream,
+ safety_settings=safety_settings,
+ tools=self._convert_tools_to_glm_tool(tools) if tools else None
+ )
+
+ if stream:
+ return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+ return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+ def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+ prompt_messages: list[PromptMessage]) -> LLMResult:
+ """
+ Handle llm response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm response
+ """
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=response.candidates[0].content.parts[0].text
+ )
+
+ # calculate num tokens
+ prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+ completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ # transform response
+ result = LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ message=assistant_prompt_message,
+ usage=usage,
+ )
+
+ return result
+
+ def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+ prompt_messages: list[PromptMessage]) -> Generator:
+ """
+ Handle llm stream response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm response chunk generator result
+ """
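+        # each streamed chunk can carry several parts (text and/or function calls);
+        # emit one LLMResultChunk per part and attach usage to the final chunk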
+ index = -1
+ for chunk in response:
+ for part in chunk.candidates[0].content.parts:
+ assistant_prompt_message = AssistantPromptMessage(
+ content=''
+ )
+
+ if part.text:
+ assistant_prompt_message.content += part.text
+
+ if part.function_call:
+ assistant_prompt_message.tool_calls = [
+ AssistantPromptMessage.ToolCall(
+ id=part.function_call.name,
+ type='function',
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=part.function_call.name,
+ arguments=json.dumps({
+ key: value
+ for key, value in part.function_call.args.items()
+ })
+ )
+ )
+ ]
+
+ index += 1
+
+                if not chunk.candidates[0].finish_reason:
+ # transform assistant message to prompt message
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=assistant_prompt_message
+ )
+ )
+ else:
+
+ # calculate num tokens
+ prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+ completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=assistant_prompt_message,
+ finish_reason=chunk.candidates[0].finish_reason,
+ usage=usage
+ )
+ )
+
+ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
+ """
+ Convert a single message to a string.
+
+ :param message: PromptMessage to convert.
+ :return: String representation of the message.
+ """
+ human_prompt = "\n\nuser:"
+ ai_prompt = "\n\nmodel:"
+
+ content = message.content
+ if isinstance(content, list):
+ content = "".join(
+ c.data for c in content if c.type != PromptMessageContentType.IMAGE
+ )
+
+ if isinstance(message, UserPromptMessage):
+ message_text = f"{human_prompt} {content}"
+ elif isinstance(message, AssistantPromptMessage):
+ message_text = f"{ai_prompt} {content}"
+ elif isinstance(message, SystemPromptMessage):
+ message_text = f"{human_prompt} {content}"
+ elif isinstance(message, ToolPromptMessage):
+ message_text = f"{human_prompt} {content}"
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ return message_text
+
+ def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+ """
+ Format a single message into glm.Content for Google API
+
+ :param message: one PromptMessage
+ :return: glm Content representation of message
+ """
+ if isinstance(message, UserPromptMessage):
+            if isinstance(message.content, str):
+ glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
+ else:
+ parts = []
+ for c in message.content:
+ if c.type == PromptMessageContentType.TEXT:
+ parts.append(glm.Part.from_text(c.data))
+ else:
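+                        # image content arrives as a data URI ("data:<mime>;base64,<payload>");
+                        # split off the metadata to recover the mime type and the base64 payload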
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = glm.Part.from_data(data=base64.b64decode(data), mime_type=mime_type)
+                        parts.append(blob)
+
+                glm_content = glm.Content(role="user", parts=parts)
+ return glm_content
+ elif isinstance(message, AssistantPromptMessage):
+ if message.content:
+ glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
+ if message.tool_calls:
+ glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
+ name=message.tool_calls[0].function.name,
+ args=json.loads(message.tool_calls[0].function.arguments),
+ ))])
+ return glm_content
+ elif isinstance(message, ToolPromptMessage):
+            glm_content = glm.Content(role="function", parts=[glm.Part.from_function_response(
+                name=message.name,
+                response={
+                    "response": message.content
+                }
+            )])
+ return glm_content
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: [
+ exceptions.RetryError
+ ],
+ InvokeServerUnavailableError: [
+ exceptions.ServiceUnavailable,
+ exceptions.InternalServerError,
+ exceptions.BadGateway,
+ exceptions.GatewayTimeout,
+ exceptions.DeadlineExceeded
+ ],
+ InvokeRateLimitError: [
+ exceptions.ResourceExhausted,
+ exceptions.TooManyRequests
+ ],
+ InvokeAuthorizationError: [
+                exceptions.Unauthenticated,
+                exceptions.PermissionDenied,
+                exceptions.Forbidden
+ ],
+ InvokeBadRequestError: [
+ exceptions.BadRequest,
+ exceptions.InvalidArgument,
+ exceptions.FailedPrecondition,
+ exceptions.OutOfRange,
+ exceptions.NotFound,
+ exceptions.MethodNotAllowed,
+ exceptions.Conflict,
+ exceptions.AlreadyExists,
+ exceptions.Aborted,
+ exceptions.LengthRequired,
+ exceptions.PreconditionFailed,
+ exceptions.RequestRangeNotSatisfiable,
+ exceptions.Cancelled,
+ ]
+ }
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
new file mode 100644
index 0000000000..32db6faf89
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
@@ -0,0 +1,8 @@
+model: text-embedding-004
+model_type: text-embedding
+model_properties:
+ context_size: 2048
+pricing:
+ input: '0.00013'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
new file mode 100644
index 0000000000..2ec0eea9f2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
@@ -0,0 +1,8 @@
+model: text-multilingual-embedding-002
+model_type: text-embedding
+model_properties:
+ context_size: 2048
+pricing:
+ input: '0.00013'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..ece63806c3
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -0,0 +1,193 @@
+import base64
+import json
+import time
+from decimal import Decimal
+from typing import Optional
+
+import tiktoken
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ FetchFrom,
+ ModelPropertyKey,
+ ModelType,
+ PriceConfig,
+ PriceType,
+)
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
+
+
+class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
+ """
+ Model class for Vertex AI text embedding model.
+ """
+
+ def _invoke(self, model: str, credentials: dict,
+ texts: list[str], user: Optional[str] = None) \
+ -> TextEmbeddingResult:
+ """
+ Invoke text embedding model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return: embeddings result
+ """
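+        # same credential handling as the LLM model: a base64-encoded service account
+        # JSON key plus project id and location for aiplatform.init()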
+ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+ service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+ project_id = credentials["vertex_project_id"]
+ location = credentials["vertex_location"]
+ aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+ client = VertexTextEmbeddingModel.from_pretrained(model)
+
+ embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+ client=client,
+ texts=texts
+ )
+
+ # calc usage
+ usage = self._calc_response_usage(
+ model=model,
+ credentials=credentials,
+ tokens=embedding_used_tokens
+ )
+
+ return TextEmbeddingResult(
+ embeddings=embeddings_batch,
+ usage=usage,
+ model=model
+ )
+
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return:
+ """
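+        # approximate with tiktoken; Vertex AI embedding models are not registered with
+        # tiktoken, so this normally falls back to the cl100k_base encoding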
+ if len(texts) == 0:
+ return 0
+
+ try:
+ enc = tiktoken.encoding_for_model(model)
+ except KeyError:
+ enc = tiktoken.get_encoding("cl100k_base")
+
+ total_num_tokens = 0
+ for text in texts:
+ # calculate the number of tokens in the encoded text
+ tokenized_text = enc.encode(text)
+ total_num_tokens += len(tokenized_text)
+
+ return total_num_tokens
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+ service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+ project_id = credentials["vertex_project_id"]
+ location = credentials["vertex_location"]
+ aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+ client = VertexTextEmbeddingModel.from_pretrained(model)
+
+ # call embedding model
+ self._embedding_invoke(
+ client=client,
+ texts=['ping']
+ )
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> tuple[list[list[float]], int]:
+ """
+ Invoke embedding model
+
+ :param client: model client
+ :param texts: texts to embed
+ :return: embeddings and used tokens
+ """
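+        # get_embeddings returns one TextEmbedding per input text, each carrying its
+        # vector (.values) and token statistics (.statistics.token_count)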
+ response = client.get_embeddings(texts)
+
+ embeddings = []
+ token_usage = 0
+
+        for embedding in response:
+            embeddings.append(embedding.values)
+            token_usage += int(embedding.statistics.token_count)
+
+ return embeddings, token_usage
+
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param tokens: input tokens
+ :return: usage
+ """
+ # get input price info
+ input_price_info = self.get_price(
+ model=model,
+ credentials=credentials,
+ price_type=PriceType.INPUT,
+ tokens=tokens
+ )
+
+ # transform usage
+ usage = EmbeddingUsage(
+ tokens=tokens,
+ total_tokens=tokens,
+ unit_price=input_price_info.unit_price,
+ price_unit=input_price_info.unit,
+ total_price=input_price_info.total_amount,
+ currency=input_price_info.currency,
+ latency=time.perf_counter() - self.started_at
+ )
+
+ return usage
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.TEXT_EMBEDDING,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+ ModelPropertyKey.MAX_CHUNKS: 1,
+ },
+ parameter_rules=[],
+ pricing=PriceConfig(
+ input=Decimal(credentials.get('input_price', 0)),
+ unit=Decimal(credentials.get('unit', 0)),
+ currency=credentials.get('currency', "USD")
+ )
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py
new file mode 100644
index 0000000000..3cbfb088d1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py
@@ -0,0 +1,31 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VertexAiProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
+ model_instance = self.get_model_instance(ModelType.LLM)
+
+            # use the `gemini-1.0-pro-002` model to validate the provider credentials
+ model_instance.validate_credentials(
+ model='gemini-1.0-pro-002',
+ credentials=credentials
+ )
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validation failed')
+ raise ex
diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml
new file mode 100644
index 0000000000..8b7f216b55
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml
@@ -0,0 +1,43 @@
+provider: vertex_ai
+label:
+ en_US: Vertex AI | Google Cloud Platform
+description:
+ en_US: Vertex AI in Google Cloud Platform.
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.png
+background: "#FCFDFF"
+help:
+ title:
+ en_US: Get your Access Details from Google
+ url:
+ en_US: https://cloud.google.com/vertex-ai/
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: vertex_project_id
+ label:
+ en_US: Project ID
+ type: text-input
+ required: true
+ placeholder:
+ en_US: Enter your Google Cloud Project ID
+ - variable: vertex_location
+ label:
+ en_US: Location
+ type: text-input
+ required: true
+ placeholder:
+ en_US: Enter your Google Cloud Location
+ - variable: vertex_service_account_key
+ label:
+ en_US: Service Account Key
+ type: secret-input
+ required: true
+ placeholder:
+ en_US: Enter your Google Cloud Service Account Key in base64 format
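+      # the expected value is the entire service account JSON key file encoded with base64
+      # (e.g. `base64 -w 0 service-account.json` on Linux)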
diff --git a/api/requirements.txt b/api/requirements.txt
index c9e9c2fa29..306f600afc 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -84,3 +84,4 @@ pgvecto-rs==0.1.4
firecrawl-py==0.0.5
oss2==2.18.5
pgvector==0.2.5
+google-cloud-aiplatform==1.49.0