Add support to boto3 default connection (#5246)

### What problem does this PR solve?
 
This pull request includes changes to the initialization logic of the
`ChatModel` and `EmbeddingModel` classes to enhance the handling of AWS
credentials.

Use cases:
- Use env variables for credentials instead of managing them on the DB 
- Easy connection when deploying on an AWS machine

### Type of change

- [X] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Omar Leonardo Sanchez Granados 2025-02-23 22:01:14 -05:00 committed by GitHub
parent a0b461a18e
commit 4f2816c01c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 3 deletions

View File

@ -703,7 +703,12 @@ class BedrockChat(Base):
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
self.bedrock_region = json.loads(key).get('bedrock_region', '')
self.model_name = model_name
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client('bedrock-runtime')
else:
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def chat(self, system, history, gen_conf):

View File

@ -476,8 +476,13 @@ class BedrockEmbed(Base):
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
self.bedrock_region = json.loads(key).get('bedrock_region', '')
self.model_name = model_name
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client('bedrock-runtime')
else:
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def encode(self, texts: list):
texts = [truncate(t, 8196) for t in texts]