Support for Vertex AI (#4586)
This commit is contained in:
parent 9ae72cdcf4
commit 296887754f
@ -2,6 +2,7 @@
 - anthropic
 - azure_openai
 - google
+- vertex_ai
 - nvidia
 - cohere
 - bedrock
Binary file not shown (new image, 18 KiB).
@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24px" height="24px"><path d="M20,13.89A.77.77,0,0,0,19,13.73l-7,5.14v.22a.72.72,0,1,1,0,1.43v0a.74.74,0,0,0,.45-.15l7.41-5.47A.76.76,0,0,0,20,13.89Z" style="fill:#669df6"/><path d="M12,20.52a.72.72,0,0,1,0-1.43h0v-.22L5,13.73a.76.76,0,0,0-1,.16.74.74,0,0,0,.16,1l7.41,5.47a.73.73,0,0,0,.44.15v0Z" style="fill:#aecbfa"/><path d="M12,18.34a1.47,1.47,0,1,0,1.47,1.47A1.47,1.47,0,0,0,12,18.34Zm0,2.18a.72.72,0,1,1,.72-.71A.71.71,0,0,1,12,20.52Z" style="fill:#4285f4"/><path d="M6,6.11a.76.76,0,0,1-.75-.75V3.48a.76.76,0,1,1,1.51,0V5.36A.76.76,0,0,1,6,6.11Z" style="fill:#aecbfa"/><circle cx="5.98" cy="12" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="9.79" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="7.57" r="0.76" style="fill:#aecbfa"/><path d="M18,8.31a.76.76,0,0,1-.75-.76V5.67a.75.75,0,1,1,1.5,0V7.55A.75.75,0,0,1,18,8.31Z" style="fill:#4285f4"/><circle cx="18.02" cy="12.01" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="9.76" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="3.48" r="0.76" style="fill:#4285f4"/><path d="M12,15a.76.76,0,0,1-.75-.75V12.34a.76.76,0,0,1,1.51,0v1.89A.76.76,0,0,1,12,15Z" style="fill:#669df6"/><circle cx="12" cy="16.45" r="0.76" style="fill:#669df6"/><circle cx="12" cy="10.14" r="0.76" style="fill:#669df6"/><circle cx="12" cy="7.92" r="0.76" style="fill:#669df6"/><path d="M15,10.54a.76.76,0,0,1-.75-.75V7.91a.76.76,0,1,1,1.51,0V9.79A.76.76,0,0,1,15,10.54Z" style="fill:#4285f4"/><circle cx="15.01" cy="5.69" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="14.19" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="11.97" r="0.76" style="fill:#4285f4"/><circle cx="8.99" cy="14.19" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="7.92" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="5.69" r="0.76" style="fill:#aecbfa"/><path d="M9,12.73A.76.76,0,0,1,8.24,12V10.1a.75.75,0,1,1,1.5,0V12A.75.75,0,0,1,9,12.73Z" style="fill:#aecbfa"/></svg>
15 api/core/model_runtime/model_providers/vertex_ai/_common.py Normal file
@ -0,0 +1,15 @@
from core.model_runtime.errors.invoke import InvokeError


class _CommonVertexAi:
    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        pass
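Note: `_invoke_error_mapping` is deliberately left as a stub here; the concrete model classes later in this commit override it with real mappings. As a rough illustration only (not part of this commit; the helper name and fallback behaviour are hypothetical), a caller could translate a raised provider exception into its unified error type like this:

def to_unified_error_type(mapping: dict[type, list[type]], error: Exception) -> type:
    # return the first unified error type whose provider-side exception
    # classes match the raised error
    for unified_type, provider_types in mapping.items():
        if isinstance(error, tuple(provider_types)):
            return unified_type
    # assumption: fall back to a plain Exception when nothing matches
    return Exception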
@ -0,0 +1,38 @@
model: gemini-1.0-pro-vision-001
label:
  en_US: Gemini 1.0 Pro Vision
model_type: llm
features:
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 2048
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,38 @@
model: gemini-1.0-pro-002
label:
  en_US: Gemini 1.0 Pro
model_type: llm
features:
  - agent-thought
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32760
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,38 @@
model: gemini-1.5-flash-preview-0514
label:
  en_US: Gemini 1.5 Flash
model_type: llm
features:
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,39 @@
model: gemini-1.5-pro-preview-0514
label:
  en_US: Gemini 1.5 Pro
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
438 api/core/model_runtime/model_providers/vertex_ai/llm/llm.py Normal file
@ -0,0 +1,438 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Optional, Union

import google.api_core.exceptions as exceptions
import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""


class VertexAiLargeLanguageModel(LargeLanguageModel):

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return self._get_num_tokens_by_gpt2(prompt)

    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
        """
        Format a list of messages into a full prompt for the Google model

        :param messages: List of PromptMessage to combine.
        :return: Combined string with necessary human_prompt and ai_prompt tags.
        """
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message)
            for message in messages
        )

        return text.rstrip()

    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
        """
        Convert tool messages to glm tools

        :param tools: tool messages
        :return: glm tools
        """
        return glm.Tool(
            function_declarations=[
                glm.FunctionDeclaration(
                    name=tool.name,
                    parameters=glm.Schema(
                        type=glm.Type.OBJECT,
                        properties={
                            key: {
                                'type_': value.get('type', 'string').upper(),
                                'description': value.get('description', ''),
                                'enum': value.get('enum', [])
                            } for key, value in tool.parameters.get('properties', {}).items()
                        },
                        required=tool.parameters.get('required', [])
                    ),
                ) for tool in tools
            ]
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            ping_message = SystemPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None
                  ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: credentials kwargs
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        config_kwargs = model_parameters.copy()
        config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)

        if stop:
            config_kwargs["stop_sequences"] = stop

        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)

        history = []
        system_instruction = GEMINI_BLOCK_MODE_PROMPT
        # hack for gemini-pro-vision, which currently does not support multi-turn chat
        if model == "gemini-1.0-pro-vision-001":
            last_msg = prompt_messages[-1]
            content = self._format_message_to_glm_content(last_msg)
            history.append(content)
        else:
            for msg in prompt_messages:
                if isinstance(msg, SystemPromptMessage):
                    system_instruction = msg.content
                else:
                    content = self._format_message_to_glm_content(msg)
                    if history and history[-1].role == content.role:
                        history[-1].parts.extend(content.parts)
                    else:
                        history.append(content)

        safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

        google_model = glm.GenerativeModel(
            model_name=model,
            system_instruction=system_instruction
        )

        response = google_model.generate_content(
            contents=history,
            generation_config=glm.GenerationConfig(
                **config_kwargs
            ),
            stream=stream,
            safety_settings=safety_settings,
            tools=self._convert_tools_to_glm_tool(tools) if tools else None
        )

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=response.candidates[0].content.parts[0].text
        )

        # calculate num tokens
        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        index = -1
        for chunk in response:
            for part in chunk.candidates[0].content.parts:
                assistant_prompt_message = AssistantPromptMessage(
                    content=''
                )

                if part.text:
                    assistant_prompt_message.content += part.text

                if part.function_call:
                    assistant_prompt_message.tool_calls = [
                        AssistantPromptMessage.ToolCall(
                            id=part.function_call.name,
                            type='function',
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=part.function_call.name,
                                arguments=json.dumps({
                                    key: value
                                    for key, value in part.function_call.args.items()
                                })
                            )
                        )
                    ]

                index += 1

                if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:
                    # transform assistant message to prompt message
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=index,
                            message=assistant_prompt_message
                        )
                    )
                else:
                    # calculate num tokens
                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

                    # transform usage
                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=index,
                            message=assistant_prompt_message,
                            finish_reason=chunk.candidates[0].finish_reason,
                            usage=usage
                        )
                    )

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        """
        human_prompt = "\n\nuser:"
        ai_prompt = "\n\nmodel:"

        content = message.content
        if isinstance(content, list):
            content = "".join(
                c.data for c in content if c.type != PromptMessageContentType.IMAGE
            )

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, ToolPromptMessage):
            message_text = f"{human_prompt} {content}"
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
        """
        Format a single message into glm.Content for Google API

        :param message: one PromptMessage
        :return: glm Content representation of message
        """
        if isinstance(message, UserPromptMessage):
            glm_content = glm.Content(role="user", parts=[])

            if isinstance(message.content, str):
                glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
            else:
                parts = []
                for c in message.content:
                    if c.type == PromptMessageContentType.TEXT:
                        parts.append(glm.Part.from_text(c.data))
                    else:
                        metadata, data = c.data.split(',', 1)
                        mime_type = metadata.split(';', 1)[0].split(':')[1]
                        blob = {"inline_data": {"mime_type": mime_type, "data": data}}
                        parts.append(blob)

                glm_content = glm.Content(role="user", parts=parts)
            return glm_content
        elif isinstance(message, AssistantPromptMessage):
            if message.content:
                glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
            if message.tool_calls:
                glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
                    name=message.tool_calls[0].function.name,
                    args=json.loads(message.tool_calls[0].function.arguments),
                ))])
            return glm_content
        elif isinstance(message, ToolPromptMessage):
            glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
                name=message.name,
                response={
                    "response": message.content
                }
            ))])
            return glm_content
        else:
            raise ValueError(f"Got unknown type {message}")

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                exceptions.RetryError
            ],
            InvokeServerUnavailableError: [
                exceptions.ServiceUnavailable,
                exceptions.InternalServerError,
                exceptions.BadGateway,
                exceptions.GatewayTimeout,
                exceptions.DeadlineExceeded
            ],
            InvokeRateLimitError: [
                exceptions.ResourceExhausted,
                exceptions.TooManyRequests
            ],
            InvokeAuthorizationError: [
                exceptions.Unauthenticated,
                exceptions.PermissionDenied,
                exceptions.Forbidden
            ],
            InvokeBadRequestError: [
                exceptions.BadRequest,
                exceptions.InvalidArgument,
                exceptions.FailedPrecondition,
                exceptions.OutOfRange,
                exceptions.NotFound,
                exceptions.MethodNotAllowed,
                exceptions.Conflict,
                exceptions.AlreadyExists,
                exceptions.Aborted,
                exceptions.LengthRequired,
                exceptions.PreconditionFailed,
                exceptions.RequestRangeNotSatisfiable,
                exceptions.Cancelled,
            ]
        }
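A rough usage sketch for the class above, assuming the framework's public invoke() entry point on the base LargeLanguageModel class (not shown in this diff; it handles timing and usage bookkeeping), a valid base64-encoded service-account key, and a real Google Cloud project. All literal values below are placeholders:

import base64

from core.model_runtime.entities.message_entities import UserPromptMessage

llm = VertexAiLargeLanguageModel()
credentials = {
    # base64-encoded service-account JSON (placeholder value)
    "vertex_service_account_key": base64.b64encode(b'{"type": "service_account"}').decode(),
    "vertex_project_id": "my-gcp-project",  # hypothetical project id
    "vertex_location": "us-central1",       # hypothetical region
}
result = llm.invoke(
    model="gemini-1.0-pro-002",
    credentials=credentials,
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.2, "max_tokens_to_sample": 256},
    stream=False,
)
print(result.message.content)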
@ -0,0 +1,8 @@
model: text-embedding-004
model_type: text-embedding
model_properties:
  context_size: 2048
pricing:
  input: '0.00013'
  unit: '0.001'
  currency: USD
@ -0,0 +1,8 @@
model: text-multilingual-embedding-002
model_type: text-embedding
model_properties:
  context_size: 2048
pricing:
  input: '0.00013'
  unit: '0.001'
  currency: USD
@ -0,0 +1,193 @@
import base64
import json
import time
from decimal import Decimal
from typing import Optional

import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelPropertyKey,
    ModelType,
    PriceConfig,
    PriceType,
)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi


class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
    """
    Model class for Vertex AI text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: embeddings result
        """
        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)

        client = VertexTextEmbeddingModel.from_pretrained(model)

        embeddings_batch, embedding_used_tokens = self._embedding_invoke(
            client=client,
            texts=texts
        )

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=embedding_used_tokens
        )

        return TextEmbeddingResult(
            embeddings=embeddings_batch,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: number of tokens
        """
        if len(texts) == 0:
            return 0

        try:
            enc = tiktoken.encoding_for_model(model)
        except KeyError:
            enc = tiktoken.get_encoding("cl100k_base")

        total_num_tokens = 0
        for text in texts:
            # calculate the number of tokens in the encoded text
            tokenized_text = enc.encode(text)
            total_num_tokens += len(tokenized_text)

        return total_num_tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
            project_id = credentials["vertex_project_id"]
            location = credentials["vertex_location"]
            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)

            client = VertexTextEmbeddingModel.from_pretrained(model)

            # call embedding model
            self._embedding_invoke(
                client=client,
                texts=['ping']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> tuple[list[float], int]:
        """
        Invoke embedding model

        :param client: model client
        :param texts: texts to embed
        :return: embeddings and used tokens
        """
        response = client.get_embeddings(texts)

        embeddings = []
        token_usage = 0

        for i in range(len(response)):
            embeddings.append(response[i].values)
            token_usage += int(response[i].statistics.token_count)

        return embeddings, token_usage

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            )
        )

        return entity
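A matching sketch for the embedding class, under the same assumptions as the LLM sketch above (the public invoke() wrapper from the base TextEmbeddingModel class, not shown in this diff, and the same placeholder credential dict):

embedding_model = VertexAiTextEmbeddingModel()
result = embedding_model.invoke(
    model="text-embedding-004",
    credentials=credentials,  # same placeholder credential dict as above
    texts=["hello world", "vertex ai embeddings"],
)
# result follows the TextEmbeddingResult shape defined in this commit
print(len(result.embeddings), result.usage.tokens)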
@ -0,0 +1,31 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class VertexAiProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        If validation fails, an exception is raised.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use the `gemini-1.0-pro-002` model for validation
            model_instance.validate_credentials(
                model='gemini-1.0-pro-002',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
@ -0,0 +1,43 @@
provider: vertex_ai
label:
  en_US: Vertex AI | Google Cloud Platform
description:
  en_US: Vertex AI in Google Cloud Platform.
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.png
background: "#FCFDFF"
help:
  title:
    en_US: Get your Access Details from Google
  url:
    en_US: https://cloud.google.com/vertex-ai/
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: vertex_project_id
      label:
        en_US: Project ID
      type: text-input
      required: true
      placeholder:
        en_US: Enter your Google Cloud Project ID
    - variable: vertex_location
      label:
        en_US: Location
      type: text-input
      required: true
      placeholder:
        en_US: Enter your Google Cloud Location
    - variable: vertex_service_account_key
      label:
        en_US: Service Account Key
      type: secret-input
      required: true
      placeholder:
        en_US: Enter your Google Cloud Service Account Key in base64 format
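The vertex_service_account_key credential above is supplied base64-encoded, matching the base64.b64decode(...) calls in the model code in this commit. A small sketch for producing that value; the key-file path and project values are illustrative:

import base64

# encode the downloaded service-account JSON key file (hypothetical path)
with open("my-service-account.json", "rb") as f:
    encoded_key = base64.b64encode(f.read()).decode("utf-8")

credentials = {
    "vertex_project_id": "my-gcp-project",  # hypothetical project id
    "vertex_location": "us-central1",       # hypothetical region
    "vertex_service_account_key": encoded_key,
}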
@ -84,3 +84,4 @@ pgvecto-rs==0.1.4
 firecrawl-py==0.0.5
 oss2==2.18.5
 pgvector==0.2.5
+google-cloud-aiplatform==1.49.0