diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
index 04789197ee..97b7692044 100644
--- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
+++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
@@ -84,8 +84,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
     Model class for Cohere large language model.
     """
 
-    sagemaker_client: Any = None
+    sagemaker_session: Any = None
     predictor: Any = None
+    sagemaker_endpoint: str = None
 
     def _handle_chat_generate_response(
         self,
@@ -211,7 +212,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
        :return: full response or stream response chunk generator result
         """
-        if not self.sagemaker_client:
+        if not self.sagemaker_session:
             access_key = credentials.get("aws_access_key_id")
             secret_key = credentials.get("aws_secret_access_key")
             aws_region = credentials.get("aws_region")
@@ -226,11 +227,14 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             else:
                 boto_session = boto3.Session()
 
-            self.sagemaker_client = boto_session.client("sagemaker")
-            sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
+            sagemaker_client = boto_session.client("sagemaker")
+            self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
+
+        if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
+            self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")
 
             self.predictor = Predictor(
-                endpoint_name=credentials.get("sagemaker_endpoint"),
-                sagemaker_session=sagemaker_session,
+                endpoint_name=self.sagemaker_endpoint,
+                sagemaker_session=self.sagemaker_session,
                 serializer=serializers.JSONSerializer(),
             )
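
The diff caches the sagemaker `Session` on the model instance and rebuilds the `Predictor` whenever the configured endpoint name changes, instead of only on the first invocation. Below is a minimal sketch of the resulting initialization flow, assuming the public sagemaker SDK APIs (`Session`, `Predictor`, `serializers.JSONSerializer`). The class name `SageMakerClientHolder` and the helper `_ensure_predictor` are hypothetical stand-ins for illustration; the boto3 session construction between the two hunks is not shown in the diff, so that branch is an assumption here.

```python
# Sketch of the lazy-init pattern the change introduces: create the sagemaker
# Session once, rebuild the Predictor only when the endpoint name changes.
from typing import Any, Optional

import boto3
from sagemaker import serializers
from sagemaker.predictor import Predictor
from sagemaker.session import Session


class SageMakerClientHolder:  # hypothetical stand-in for SageMakerLargeLanguageModel
    sagemaker_session: Any = None
    predictor: Any = None
    sagemaker_endpoint: Optional[str] = None

    def _ensure_predictor(self, credentials: dict) -> Predictor:
        # Create the sagemaker Session once and reuse it across calls.
        if not self.sagemaker_session:
            access_key = credentials.get("aws_access_key_id")
            secret_key = credentials.get("aws_secret_access_key")
            aws_region = credentials.get("aws_region")
            # Assumption: the diff does not show this branch; explicit keys
            # are used when provided, otherwise the default credential chain.
            if access_key and secret_key:
                boto_session = boto3.Session(
                    aws_access_key_id=access_key,
                    aws_secret_access_key=secret_key,
                    region_name=aws_region,
                )
            else:
                boto_session = boto3.Session()

            sagemaker_client = boto_session.client("sagemaker")
            self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)

        # Rebuild the Predictor only when the configured endpoint name changes.
        if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
            self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")

            self.predictor = Predictor(
                endpoint_name=self.sagemaker_endpoint,
                sagemaker_session=self.sagemaker_session,
                serializer=serializers.JSONSerializer(),
            )
        return self.predictor
```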