fix llm integration problem: It doesn't work on docker env (#8701)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
ybalbert001 2024-09-24 10:33:30 +08:00 committed by GitHub
parent 21e9608b23
commit 7c485f8bb8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -85,7 +85,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
""" """
sagemaker_client: Any = None sagemaker_client: Any = None
sagemaker_sess: Any = None
predictor: Any = None predictor: Any = None
def _handle_chat_generate_response( def _handle_chat_generate_response(
@ -213,23 +212,22 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
if not self.sagemaker_client: if not self.sagemaker_client:
access_key = credentials.get("access_key") access_key = credentials.get("aws_access_key_id")
secret_key = credentials.get("secret_key") secret_key = credentials.get("aws_secret_access_key")
aws_region = credentials.get("aws_region") aws_region = credentials.get("aws_region")
boto_session = None
if aws_region: if aws_region:
if access_key and secret_key: if access_key and secret_key:
self.sagemaker_client = boto3.client( boto_session = boto3.Session(
"sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region,
) )
else: else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) boto_session = boto3.Session(region_name=aws_region)
else: else:
self.sagemaker_client = boto3.client("sagemaker-runtime") boto_session = boto3.Session()
sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client) self.sagemaker_client = boto_session.client("sagemaker")
sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
self.predictor = Predictor( self.predictor = Predictor(
endpoint_name=credentials.get("sagemaker_endpoint"), endpoint_name=credentials.get("sagemaker_endpoint"),
sagemaker_session=sagemaker_session, sagemaker_session=sagemaker_session,