diff --git a/api/core/model_runtime/model_providers/sagemaker/__init__.py b/api/core/model_runtime/model_providers/sagemaker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png new file mode 100644 index 0000000000..0abe07a78f Binary files /dev/null and b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png new file mode 100644 index 0000000000..6b88942a5c Binary files /dev/null and b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py b/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py new file mode 100644 index 0000000000..f8e7757a96 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -0,0 +1,238 @@ +import json +import logging +from collections.abc import Generator +from typing import Any, Optional, Union + +import boto3 + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import 
class SageMakerLargeLanguageModel(LargeLanguageModel):
    """
    Model class for large language models hosted on an Amazon SageMaker endpoint.
    """
    # boto3 "sagemaker-runtime" client, created lazily on first use
    sagemaker_client: Any = None
    # (access key, secret key, region) the cached client was built from;
    # stays None when the client was assigned from outside this class
    _sagemaker_client_key: Any = None

    def _get_sagemaker_client(self, credentials: dict) -> Any:
        """
        Return a "sagemaker-runtime" client matching the given credentials

        The client is cached on the instance and rebuilt whenever the access key,
        secret key or region differ from the ones it was created with.

        :param credentials: model credentials
        :return: boto3 sagemaker-runtime client
        """
        # The credential form declares the `aws_*` names; `access_key` and
        # `secret_key` are still accepted for backward compatibility.
        access_key = credentials.get('aws_access_key_id') or credentials.get('access_key')
        secret_key = credentials.get('aws_secret_access_key') or credentials.get('secret_key')
        aws_region = credentials.get('aws_region')
        client_key = (access_key, secret_key, aws_region)

        if self.sagemaker_client is not None and self._sagemaker_client_key in (None, client_key):
            return self.sagemaker_client

        client_kwargs = {}
        if aws_region:
            client_kwargs['region_name'] = aws_region
        if access_key and secret_key:
            # without an explicit key pair boto3 falls back to the running environment
            client_kwargs['aws_access_key_id'] = access_key
            client_kwargs['aws_secret_access_key'] = secret_key

        self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)
        self._sagemaker_client_key = client_key
        return self.sagemaker_client

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        :raises ValueError: if prompt_messages is empty
        """
        if not prompt_messages:
            raise ValueError('prompt_messages must contain at least one message')

        client = self._get_sagemaker_client(credentials)

        # NOTE(review): only the first prompt message and the stop words are sent.
        # `model_parameters`, `tools` and `stream` are not forwarded, and a blocking
        # LLMResult is returned even when `stream` is True.
        payload = {
            "inputs": prompt_messages[0].content,
            "parameters": {"stop": stop},
            "history": []
        }
        response_model = client.invoke_endpoint(
            EndpointName=credentials.get('sagemaker_endpoint'),
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        assistant_text = response_model['Body'].read().decode('utf8')

        # prompt and completion token counts are passed as zero
        usage = self._calc_response_usage(model, credentials, 0, 0)

        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content=assistant_text),
            usage=usage
        )

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        Token counting is not implemented for this provider, so zero is returned.

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: always 0
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the model mode cannot be resolved
        """
        try:
            # resolving the model mode is the only check performed here
            self.get_model_mode(model)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name
        :param credentials: model credentials; `mode`, `context_length`,
            `support_function_call` and `support_vision` are read from it
        :return: model entity describing the customizable model
        """
        # These names are used below but are missing from the module-level imports.
        from core.model_runtime.entities.model_entities import (
            ModelFeature,
            ModelPropertyKey,
            ParameterRule,
            ParameterType,
        )

        context_length = credentials.get('context_length', 2048)

        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                ),
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=context_length,
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            )
        ]

        # `mode` is optional in the credential form and defaults to chat there
        completion_type = LLMMode.value_of(credentials.get("mode") or LLMMode.CHAT.value)
        logger.debug("completion_type : %s", completion_type.value)

        features = []
        if credentials.get('support_function_call', False):
            features.append(ModelFeature.TOOL_CALL)
        if credentials.get('support_vision', False):
            features.append(ModelFeature.VISION)

        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=features,
            model_properties={
                # NOTE(review): stored as the plain string, the form LLMMode.value_of
                # accepts; confirm against how get_model_mode reads this property
                ModelPropertyKey.MODE: completion_type.value,
                ModelPropertyKey.CONTEXT_SIZE: context_length
            },
            parameter_rules=rules
        )
class SageMakerRerankModel(RerankModel):
    """
    Model class for rerank models hosted on an Amazon SageMaker endpoint.
    """
    # boto3 "sagemaker-runtime" client, created lazily on first use
    sagemaker_client: Any = None
    # (access key, secret key, region) the cached client was built from;
    # stays None when the client was assigned from outside this class
    _sagemaker_client_key: Any = None

    def _get_sagemaker_client(self, credentials: dict) -> Any:
        """
        Return a "sagemaker-runtime" client matching the given credentials

        The client is cached on the instance and rebuilt whenever the access key,
        secret key or region differ from the ones it was created with.

        :param credentials: model credentials
        :return: boto3 sagemaker-runtime client
        """
        access_key = credentials.get('aws_access_key_id')
        secret_key = credentials.get('aws_secret_access_key')
        aws_region = credentials.get('aws_region')
        client_key = (access_key, secret_key, aws_region)

        if self.sagemaker_client is not None and self._sagemaker_client_key in (None, client_key):
            return self.sagemaker_client

        client_kwargs = {}
        if aws_region:
            client_kwargs['region_name'] = aws_region
        if access_key and secret_key:
            # without an explicit key pair boto3 falls back to the running environment
            client_kwargs['aws_access_key_id'] = access_key
            client_kwargs['aws_secret_access_key'] = secret_key

        self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)
        self._sagemaker_client_key = client_key
        return self.sagemaker_client

    def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
        """
        Score every document against the query with the SageMaker endpoint

        :param query_input: search query
        :param docs: docs for reranking
        :param rerank_endpoint: name of the SageMaker endpoint
        :return: list of scores read from the `scores` field of the response
        """
        payload = {
            # the query is repeated once per document
            "inputs": [query_input] * len(docs),
            "docs": docs
        }
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=rerank_endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        scores = json.loads(response_model['Body'].read().decode('utf8'))['scores']
        # a single score may come back as a scalar
        return scores if isinstance(scores, list) else [scores]

    def _invoke(self, model: str, credentials: dict,
                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result, best score first; `index` is the position of the
            document in `docs`
        """
        if not docs:
            return RerankResult(
                model=model,
                docs=[]
            )

        self._get_sagemaker_client(credentials)
        sagemaker_endpoint = credentials.get('sagemaker_endpoint')

        try:
            scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
        except Exception:
            # log for diagnosis, then let the caller see the failure instead of None
            logger.exception('SageMaker rerank invocation failed, endpoint: %s', sagemaker_endpoint)
            raise

        # keep the original position with each document before sorting by score
        ranked = sorted(
            enumerate(zip(docs, scores)),
            key=lambda item: item[1][1],
            reverse=True
        )

        rerank_documents = [
            RerankDocument(index=idx, text=text, score=score)
            for idx, (text, score) in ranked
            if score_threshold is None or score >= score_threshold
        ]

        if top_n is not None:
            rerank_documents = rerank_documents[:max(top_n, 0)]

        return RerankResult(
            model=model,
            docs=rerank_documents
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if a sample rerank request fails
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name
        :param credentials: model credentials (not read here)
        :return: model entity with no model properties and no parameter rules
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={},
            parameter_rules=[]
        )
+ """ + pass diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml new file mode 100644 index 0000000000..290cb0edab --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -0,0 +1,125 @@ +provider: sagemaker +label: + zh_Hans: Sagemaker + en_US: Sagemaker +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +description: + en_US: Customized model on Sagemaker + zh_Hans: Sagemaker上的私有化部署的模型 +background: "#ECE9E3" +help: + title: + en_US: How to deploy customized model on Sagemaker + zh_Hans: 如何在Sagemaker上的私有化部署的模型 + url: + en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint + zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9 +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: sagemaker_endpoint + label: + en_US: sagemaker endpoint + type: text-input + required: true + placeholder: + zh_Hans: 请输入你的Sagemaker推理端点 + en_US: Enter your Sagemaker Inference endpoint + - variable: aws_access_key_id + required: false + label: + en_US: Access Key (If not provided, credentials are obtained from the running environment.) 
+ zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。) + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 在此输入您的 Access Key + - variable: aws_secret_access_key + required: false + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 在此输入您的 Secret Access Key + - variable: aws_region + required: false + label: + en_US: AWS Region + zh_Hans: AWS 地区 + type: select + default: us-east-1 + options: + - value: us-east-1 + label: + en_US: US East (N. Virginia) + zh_Hans: 美国东部 (弗吉尼亚北部) + - value: us-west-2 + label: + en_US: US West (Oregon) + zh_Hans: 美国西部 (俄勒冈州) + - value: ap-southeast-1 + label: + en_US: Asia Pacific (Singapore) + zh_Hans: 亚太地区 (新加坡) + - value: ap-northeast-1 + label: + en_US: Asia Pacific (Tokyo) + zh_Hans: 亚太地区 (东京) + - value: eu-central-1 + label: + en_US: Europe (Frankfurt) + zh_Hans: 欧洲 (法兰克福) + - value: us-gov-west-1 + label: + en_US: AWS GovCloud (US-West) + zh_Hans: AWS GovCloud (US-West) + - value: ap-southeast-2 + label: + en_US: Asia Pacific (Sydney) + zh_Hans: 亚太地区 (悉尼) + - value: cn-north-1 + label: + en_US: AWS Beijing (cn-north-1) + zh_Hans: 中国北京 (cn-north-1) + - value: cn-northwest-1 + label: + en_US: AWS Ningxia (cn-northwest-1) + zh_Hans: 中国宁夏 (cn-northwest-1) diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py new file mode 100644 index 0000000000..4b2858b1a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -0,0 +1,214 @@ +import itertools +import json +import logging +import time +from typing import Any, Optional + +import boto3 + +from 
def batch_generator(generator, batch_size):
    """
    Yield successive lists of at most ``batch_size`` items taken from ``generator``.

    Any iterable is accepted. The last batch may be shorter than ``batch_size``;
    nothing is yielded for an empty iterable or when ``batch_size`` is 0.

    :param generator: iterable to split into batches
    :param batch_size: maximum number of items per batch
    :return: generator of lists
    """
    # Obtain one iterator up front: islice() on a re-iterable such as a list
    # would restart from the first item on every pass and never terminate.
    iterator = iter(generator)
    while batch := list(itertools.islice(iterator, batch_size)):
        yield batch
class SageMakerEmbeddingModel(TextEmbeddingModel):
    """
    Model class for text embedding models hosted on an Amazon SageMaker endpoint.
    """
    # boto3 "sagemaker-runtime" client, created lazily on first use
    sagemaker_client: Any = None
    # (access key, secret key, region) the cached client was built from;
    # stays None when the client was assigned from outside this class
    _sagemaker_client_key: Any = None

    def _get_sagemaker_client(self, credentials: dict) -> Any:
        """
        Return a "sagemaker-runtime" client matching the given credentials

        The client is cached on the instance and rebuilt whenever the access key,
        secret key or region differ from the ones it was created with.

        :param credentials: model credentials
        :return: boto3 sagemaker-runtime client
        """
        access_key = credentials.get('aws_access_key_id')
        secret_key = credentials.get('aws_secret_access_key')
        aws_region = credentials.get('aws_region')
        client_key = (access_key, secret_key, aws_region)

        if self.sagemaker_client is not None and self._sagemaker_client_key in (None, client_key):
            return self.sagemaker_client

        client_kwargs = {}
        if aws_region:
            client_kwargs['region_name'] = aws_region
        if access_key and secret_key:
            # without an explicit key pair boto3 falls back to the running environment
            client_kwargs['aws_access_key_id'] = access_key
            client_kwargs['aws_secret_access_key'] = secret_key

        self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)
        self._sagemaker_client_key = client_key
        return self.sagemaker_client

    def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]):
        """
        Embed one batch of texts with the SageMaker endpoint

        :param sm_client: boto3 sagemaker-runtime client
        :param endpoint_name: name of the SageMaker endpoint
        :param content_list: texts to embed
        :return: value of the `embeddings` field of the endpoint response
        """
        payload = {
            "inputs": content_list,
            "parameters": {},
            "is_query": False,
            "instruction": ''
        }
        response_model = sm_client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        return json.loads(json_str)['embeddings']

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        client = self._get_sagemaker_client(credentials)
        sagemaker_endpoint = credentials.get('sagemaker_endpoint')

        # each text is cut to CONTEXT_SIZE characters before being sent
        truncated_texts = [item[:CONTEXT_SIZE] for item in texts]

        all_embeddings = []
        try:
            for batch in batch_generator(iter(truncated_texts), batch_size=BATCH_SIZE):
                all_embeddings.extend(self._sagemaker_embedding(client, sagemaker_endpoint, batch))
        except Exception:
            # log for diagnosis, then let the caller see the failure instead of None
            logger.exception('SageMaker embedding invocation failed, endpoint: %s', sagemaker_endpoint)
            raise

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=0  # It's not SAAS API, usage is meaningless
        )

        return TextEmbeddingResult(
            embeddings=all_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        Token counting is not implemented for this provider, so zero is returned.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: always 0
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        Sends one short text to the configured endpoint.

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the sample request fails
        """
        try:
            client = self._get_sagemaker_client(credentials)
            self._sagemaker_embedding(client, credentials.get('sagemaker_endpoint'), ['ping'])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        return EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name
        :param credentials: model credentials (not read here)
        :return: model entity advertising CONTEXT_SIZE and BATCH_SIZE
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
                ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
            },
            parameter_rules=[]
        )
def test_validate_provider_credentials():
    """
    Provider-level validation accepts any credentials without raising.
    """
    provider = SageMakerProvider()

    # SageMakerProvider.validate_provider_credentials is a no-op (credentials are
    # configured per model), so expecting CredentialsValidateFailedError here
    # could never succeed; the call only has to complete.
    provider.validate_provider_credentials(
        credentials={}
    )
def test_invoke_model():
    """
    Rerank two documents through a live SageMaker endpoint and keep the one
    scoring at least 0.8.
    """
    model = SageMakerRerankModel()

    result = model.invoke(
        model='bge-m3-rerank-v2',
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            # key names must be the ones SageMakerRerankModel reads
            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
            # NOTE(review): the endpoint is required for any invocation; confirm the
            # environment variable name used by the integration test setup
            "sagemaker_endpoint": os.getenv("SAGEMAKER_RERANK_ENDPOINT"),
        },
        query="What is the capital of the United States?",
        docs=[
            "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
            "Census, Carson City had a population of 55,274.",
            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
            "are a political division controlled by the United States. Its capital is Saipan.",
        ],
        score_threshold=0.8
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 1
    assert result.docs[0].score >= 0.8
def test_get_num_tokens():
    """
    The embedding model reports zero tokens for an empty list of texts.
    """
    embedding_model = SageMakerEmbeddingModel()
    no_credentials = {}
    no_texts = []

    token_count = embedding_model.get_num_tokens(
        model='bge-m3-embedding',
        credentials=no_credentials,
        texts=no_texts
    )

    assert token_count == 0