mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 20:05:58 +08:00
Fix Issue: switch LLM of SageMaker endpoint doesn't take effect (#8737)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
parent
91f70d0bd9
commit
68c7e68a8a
@ -84,8 +84,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|||||||
Model class for Cohere large language model.
|
Model class for Cohere large language model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sagemaker_client: Any = None
|
sagemaker_session: Any = None
|
||||||
predictor: Any = None
|
predictor: Any = None
|
||||||
|
sagemaker_endpoint: str = None
|
||||||
|
|
||||||
def _handle_chat_generate_response(
|
def _handle_chat_generate_response(
|
||||||
self,
|
self,
|
||||||
@ -211,7 +212,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
if not self.sagemaker_client:
|
if not self.sagemaker_session:
|
||||||
access_key = credentials.get("aws_access_key_id")
|
access_key = credentials.get("aws_access_key_id")
|
||||||
secret_key = credentials.get("aws_secret_access_key")
|
secret_key = credentials.get("aws_secret_access_key")
|
||||||
aws_region = credentials.get("aws_region")
|
aws_region = credentials.get("aws_region")
|
||||||
@ -226,11 +227,14 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|||||||
else:
|
else:
|
||||||
boto_session = boto3.Session()
|
boto_session = boto3.Session()
|
||||||
|
|
||||||
self.sagemaker_client = boto_session.client("sagemaker")
|
sagemaker_client = boto_session.client("sagemaker")
|
||||||
sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
|
self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
|
||||||
|
|
||||||
|
if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
|
||||||
|
self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")
|
||||||
self.predictor = Predictor(
|
self.predictor = Predictor(
|
||||||
endpoint_name=credentials.get("sagemaker_endpoint"),
|
endpoint_name=self.sagemaker_endpoint,
|
||||||
sagemaker_session=sagemaker_session,
|
sagemaker_session=self.sagemaker_session,
|
||||||
serializer=serializers.JSONSerializer(),
|
serializer=serializers.JSONSerializer(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user