[ref] use one method to get boto client for aws bedrock (#11506)

This commit is contained in:
Warren Chen 2024-12-12 13:56:52 +08:00 committed by GitHub
parent a360af8687
commit 7b5839335a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 27 additions and 27 deletions

View File

@ -0,0 +1,21 @@
import boto3
from botocore.config import Config
def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"]
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(
service_name=service_name,
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
else:
# use iam without aksk to call
client = boto3.client(service_name=service_name, config=client_config)
return client

View File

@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
@ -173,13 +174,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response :param stream: is stream response
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
bedrock_client = boto3.client( bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

View File

@ -1,8 +1,5 @@
from typing import Optional from typing import Optional
import boto3
from botocore.config import Config
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError, InvokeAuthorizationError,
@ -14,6 +11,7 @@ from core.model_runtime.errors.invoke import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
class BedrockRerankModel(RerankModel): class BedrockRerankModel(RerankModel):
@ -48,13 +46,7 @@ class BedrockRerankModel(RerankModel):
return RerankResult(model=model, docs=docs) return RerankResult(model=model, docs=docs)
# initialize client # initialize client
client_config = Config(region_name=credentials["aws_region"]) bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
bedrock_runtime = boto3.client(
service_name="bedrock-agent-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id", ""),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
queries = [{"type": "TEXT", "textQuery": {"text": query}}] queries = [{"type": "TEXT", "textQuery": {"text": query}}]
text_sources = [] text_sources = []
for text in docs: for text in docs:

View File

@ -3,8 +3,6 @@ import logging
import time import time
from typing import Optional from typing import Optional
import boto3
from botocore.config import Config
from botocore.exceptions import ( from botocore.exceptions import (
ClientError, ClientError,
EndpointConnectionError, EndpointConnectionError,
@ -25,6 +23,7 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError, InvokeServerUnavailableError,
) )
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,14 +47,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
:param input_type: input type :param input_type: input type
:return: embeddings result :return: embeddings result
""" """
client_config = Config(region_name=credentials["aws_region"]) bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
embeddings = [] embeddings = []
token_usage = 0 token_usage = 0