Fix model provider of vertex ai (#11437)

This commit is contained in:
Kazuki Takamatsu 2024-12-08 09:44:49 +09:00 committed by GitHub
parent 266d32bd77
commit 4d7cfd0de5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 8 deletions

View File

@ -104,13 +104,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
""" """
# use Anthropic official SDK references # use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python # - https://github.com/anthropics/anthropic-sdk-python
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"] project_id = credentials["vertex_project_id"]
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
token = "" token = ""
# get access token from service account credential # get access token from service account credential
if service_account_info: if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES) credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
request = google.auth.transport.requests.Request() request = google.auth.transport.requests.Request()
credentials.refresh(request) credentials.refresh(request)
@ -478,10 +479,11 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
if stop: if stop:
config_kwargs["stop_sequences"] = stop config_kwargs["stop_sequences"] = stop
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"] project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"] location = credentials["vertex_location"]
if service_account_info: if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location) aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else: else:

View File

@ -48,10 +48,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param input_type: input type :param input_type: input type
:return: embeddings result :return: embeddings result
""" """
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"] project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"] location = credentials["vertex_location"]
if service_account_info: if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location) aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else: else:
@ -100,10 +101,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:return: :return:
""" """
try: try:
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"] project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"] location = credentials["vertex_location"]
if service_account_info: if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location) aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else: else: