mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-04 09:15:11 +08:00
Enhancement: add model provider - Amazon Sagemaker (#6255)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
parent
dc847ba145
commit
4a026fa352
Binary file not shown.
After Width: | Height: | Size: 9.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 9.5 KiB |
238
api/core/model_runtime/model_providers/sagemaker/llm/llm.py
Normal file
238
api/core/model_runtime/model_providers/sagemaker/llm/llm.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SageMakerLargeLanguageModel(LargeLanguageModel):
    """
    Model class for large language models served from an Amazon SageMaker endpoint.
    """
    # boto3 "sagemaker-runtime" client; created lazily on first use and then reused.
    sagemaker_client: Any = None

    def _get_sagemaker_client(self, credentials: dict) -> Any:
        """
        Return the SageMaker runtime client, creating and caching it on first use

        :param credentials: model credentials; ``aws_access_key_id``, ``aws_secret_access_key``
            and ``aws_region`` are read and are all optional
        :return: boto3 ``sagemaker-runtime`` client
        """
        # NOTE(review): once a client exists it is reused as-is, so credentials passed on
        # later calls are ignored — confirm instances are never shared across credential sets.
        if not self.sagemaker_client:
            access_key = credentials.get('aws_access_key_id')
            secret_key = credentials.get('aws_secret_access_key')
            aws_region = credentials.get('aws_region')

            # Only pass what was actually configured; anything omitted is left for boto3
            # to resolve from the running environment.
            client_kwargs = {}
            if aws_region:
                client_kwargs['region_name'] = aws_region
            if access_key and secret_key:
                client_kwargs['aws_access_key_id'] = access_key
                client_kwargs['aws_secret_access_key'] = secret_key

            self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)

        return self.sagemaker_client

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages; only the first message's content is sent
        :param model_parameters: model parameters (currently not forwarded to the endpoint)
        :param tools: tools for tool calling (unused)
        :param stop: stop words
        :param stream: is stream response (unused, a full response is always returned)
        :param user: unique user id (unused)
        :return: full response
        :raises ValueError: if ``prompt_messages`` is empty
        """
        if not prompt_messages:
            # ValueError is mapped to InvokeBadRequestError by `_invoke_error_mapping`.
            raise ValueError('prompt_messages must not be empty')

        client = self._get_sagemaker_client(credentials)

        # NOTE(review): `model_parameters` (temperature, top_p, max_tokens) are not part of
        # the request body — confirm the parameter names the endpoint accepts before adding them.
        payload = {
            "inputs": prompt_messages[0].content,
            "parameters": {"stop": stop},
            "history": []
        }

        response_model = client.invoke_endpoint(
            EndpointName=credentials.get('sagemaker_endpoint'),
            Body=json.dumps(payload),
            ContentType="application/json",
        )

        # The whole response body is used verbatim as the assistant's reply.
        assistant_text = response_model['Body'].read().decode('utf8')

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )

        # Token counts are reported as zero for both prompt and completion.
        usage = self._calc_response_usage(model, credentials, 0, 0)

        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: always 0; no tokenizer is available for the hosted model
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the model mode cannot be resolved
        """
        try:
            # Pass the credentials so the configured mode is actually checked.
            self.get_model_mode(model, credentials)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name
        :param credentials: model credentials; ``mode``, ``context_length``,
            ``support_function_call`` and ``support_vision`` are read
        :return: model entity describing the customizable model
        """
        # These names are not imported at module level; import them here so the
        # method does not fail with NameError.
        from core.model_runtime.entities.model_entities import (
            ModelFeature,
            ModelPropertyKey,
            ParameterRule,
            ParameterType,
        )

        context_length = credentials.get('context_length', 2048)

        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                ),
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=context_length,
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            )
        ]

        # `mode` is optional in the credential form, so fall back to chat when absent.
        # NOTE(review): the LLMMode member (not its string value) is stored in the
        # model properties below — confirm consumers expect that.
        completion_type = LLMMode.value_of(credentials.get('mode', LLMMode.CHAT.value))

        features = []
        if credentials.get('support_function_call', False):
            features.append(ModelFeature.TOOL_CALL)
        if credentials.get('support_vision', False):
            features.append(ModelFeature.VISION)

        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=features,
            model_properties={
                ModelPropertyKey.MODE: completion_type,
                ModelPropertyKey.CONTEXT_SIZE: context_length
            },
            parameter_rules=rules
        )
|
@ -0,0 +1,190 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class SageMakerRerankModel(RerankModel):
    """
    Model class for rerank models served from an Amazon SageMaker endpoint.
    """
    # boto3 "sagemaker-runtime" client; created lazily on first use and then reused.
    sagemaker_client: Any = None

    def _get_sagemaker_client(self, credentials: dict) -> Any:
        """
        Return the SageMaker runtime client, creating and caching it on first use

        :param credentials: model credentials; ``aws_access_key_id``, ``aws_secret_access_key``
            and ``aws_region`` are read and are all optional
        :return: boto3 ``sagemaker-runtime`` client
        """
        # NOTE(review): once a client exists it is reused as-is, so credentials passed on
        # later calls are ignored — confirm instances are never shared across credential sets.
        if not self.sagemaker_client:
            access_key = credentials.get('aws_access_key_id')
            secret_key = credentials.get('aws_secret_access_key')
            aws_region = credentials.get('aws_region')

            # Only pass what was actually configured; anything omitted is left for boto3
            # to resolve from the running environment.
            client_kwargs = {}
            if aws_region:
                client_kwargs['region_name'] = aws_region
            if access_key and secret_key:
                client_kwargs['aws_access_key_id'] = access_key
                client_kwargs['aws_secret_access_key'] = secret_key

            self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)

        return self.sagemaker_client

    def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
        """
        Score every document against the query with a single endpoint call

        Expects ``self.sagemaker_client`` to have been created already.

        :param query_input: search query
        :param docs: docs for reranking
        :param rerank_endpoint: SageMaker endpoint name
        :return: list of scores taken from the ``scores`` field of the JSON response
        """
        # The endpoint receives the query repeated once per document, alongside the docs.
        inputs = [query_input] * len(docs)
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=rerank_endpoint,
            Body=json.dumps(
                {
                    "inputs": inputs,
                    "docs": docs
                }
            ),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        scores = json_obj['scores']
        # A bare (non-list) score is wrapped so callers always get a list.
        return scores if isinstance(scores, list) else [scores]

    def _invoke(self, model: str, credentials: dict,
                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold; documents scoring below it are dropped
        :param top_n: top n; when positive, at most this many documents are returned
        :param user: unique user id
        :return: rerank result, ordered by descending score
        """
        if not docs:
            return RerankResult(
                model=model,
                docs=docs
            )

        try:
            self._get_sagemaker_client(credentials)
            sagemaker_endpoint = credentials.get('sagemaker_endpoint')

            scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)

            # Keep each document's original position so it survives the sort below.
            candidate_docs = [
                {"index": idx, "content": doc, "score": score}
                for idx, (doc, score) in enumerate(zip(docs, scores))
            ]
            candidate_docs.sort(key=lambda item: item['score'], reverse=True)

            rerank_documents = [
                RerankDocument(
                    index=item['index'],
                    text=item['content'],
                    score=item['score']
                )
                for item in candidate_docs
                if score_threshold is None or item['score'] >= score_threshold
            ]

            if top_n is not None and top_n > 0:
                rerank_documents = rerank_documents[:top_n]

            return RerankResult(
                model=model,
                docs=rerank_documents
            )
        except Exception:
            # Log for diagnosis, then propagate so callers (including
            # validate_credentials) see the failure instead of a silent None.
            logger.exception('Failed to invoke SageMaker rerank endpoint')
            raise

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if a sample rerank request fails
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name
        :param credentials: model credentials (unused)
        :return: model entity with no model properties or parameter rules
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={},
            parameter_rules=[]
        )
|
@ -0,0 +1,17 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SageMakerProvider(ModelProvider):
    """
    Provider entry for models hosted on Amazon SageMaker.
    """

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        This implementation performs no checks and never raises.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        return None
|
125
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Normal file
125
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
provider: sagemaker
|
||||||
|
label:
|
||||||
|
zh_Hans: Sagemaker
|
||||||
|
en_US: Sagemaker
|
||||||
|
icon_small:
|
||||||
|
en_US: icon_s_en.png
|
||||||
|
icon_large:
|
||||||
|
en_US: icon_l_en.png
|
||||||
|
description:
|
||||||
|
en_US: Customized model on Sagemaker
|
||||||
|
zh_Hans: Sagemaker上的私有化部署的模型
|
||||||
|
background: "#ECE9E3"
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: How to deploy customized model on Sagemaker
|
||||||
|
zh_Hans: 如何在Sagemaker上私有化部署模型
|
||||||
|
url:
|
||||||
|
en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
|
||||||
|
zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
|
||||||
|
supported_model_types:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
configurate_methods:
|
||||||
|
- customizable-model
|
||||||
|
model_credential_schema:
|
||||||
|
model:
|
||||||
|
label:
|
||||||
|
en_US: Model Name
|
||||||
|
zh_Hans: 模型名称
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your model name
|
||||||
|
zh_Hans: 输入模型名称
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: mode
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
label:
|
||||||
|
en_US: Completion mode
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
default: chat
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 选择对话类型
|
||||||
|
en_US: Select completion mode
|
||||||
|
options:
|
||||||
|
- value: completion
|
||||||
|
label:
|
||||||
|
en_US: Completion
|
||||||
|
zh_Hans: 补全
|
||||||
|
- value: chat
|
||||||
|
label:
|
||||||
|
en_US: Chat
|
||||||
|
zh_Hans: 对话
|
||||||
|
- variable: sagemaker_endpoint
|
||||||
|
label:
|
||||||
|
en_US: sagemaker endpoint
|
||||||
|
type: text-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 请输入你的Sagemaker推理端点
|
||||||
|
en_US: Enter your Sagemaker Inference endpoint
|
||||||
|
- variable: aws_access_key_id
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
|
||||||
|
zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
|
||||||
|
type: secret-input
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your Access Key
|
||||||
|
zh_Hans: 在此输入您的 Access Key
|
||||||
|
- variable: aws_secret_access_key
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Secret Access Key
|
||||||
|
zh_Hans: Secret Access Key
|
||||||
|
type: secret-input
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your Secret Access Key
|
||||||
|
zh_Hans: 在此输入您的 Secret Access Key
|
||||||
|
- variable: aws_region
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: AWS Region
|
||||||
|
zh_Hans: AWS 地区
|
||||||
|
type: select
|
||||||
|
default: us-east-1
|
||||||
|
options:
|
||||||
|
- value: us-east-1
|
||||||
|
label:
|
||||||
|
en_US: US East (N. Virginia)
|
||||||
|
zh_Hans: 美国东部 (弗吉尼亚北部)
|
||||||
|
- value: us-west-2
|
||||||
|
label:
|
||||||
|
en_US: US West (Oregon)
|
||||||
|
zh_Hans: 美国西部 (俄勒冈州)
|
||||||
|
- value: ap-southeast-1
|
||||||
|
label:
|
||||||
|
en_US: Asia Pacific (Singapore)
|
||||||
|
zh_Hans: 亚太地区 (新加坡)
|
||||||
|
- value: ap-northeast-1
|
||||||
|
label:
|
||||||
|
en_US: Asia Pacific (Tokyo)
|
||||||
|
zh_Hans: 亚太地区 (东京)
|
||||||
|
- value: eu-central-1
|
||||||
|
label:
|
||||||
|
en_US: Europe (Frankfurt)
|
||||||
|
zh_Hans: 欧洲 (法兰克福)
|
||||||
|
- value: us-gov-west-1
|
||||||
|
label:
|
||||||
|
en_US: AWS GovCloud (US-West)
|
||||||
|
zh_Hans: AWS GovCloud (US-West)
|
||||||
|
- value: ap-southeast-2
|
||||||
|
label:
|
||||||
|
en_US: Asia Pacific (Sydney)
|
||||||
|
zh_Hans: 亚太地区 (悉尼)
|
||||||
|
- value: cn-north-1
|
||||||
|
label:
|
||||||
|
en_US: AWS Beijing (cn-north-1)
|
||||||
|
zh_Hans: 中国北京 (cn-north-1)
|
||||||
|
- value: cn-northwest-1
|
||||||
|
label:
|
||||||
|
en_US: AWS Ningxia (cn-northwest-1)
|
||||||
|
zh_Hans: 中国宁夏 (cn-northwest-1)
|
@ -0,0 +1,214 @@
|
|||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
|
||||||
|
# Number of texts sent to the endpoint per request.
BATCH_SIZE = 20
# Maximum number of characters kept from each input text before embedding.
CONTEXT_SIZE = 8192

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def batch_generator(generator, batch_size):
    """
    Yield successive lists of at most ``batch_size`` items from ``generator``

    :param generator: any iterable; iterators and re-iterable containers are both accepted
    :param batch_size: maximum number of items per batch
    :return: generator of non-empty lists; the final batch may be shorter
    """
    # Obtain a single iterator up front. Slicing a re-iterable (e.g. a list) directly
    # would restart from its first element on every pass and never terminate.
    iterator = iter(generator)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        if not batch:
            break
        yield batch
|
||||||
|
|
||||||
|
class SageMakerEmbeddingModel(TextEmbeddingModel):
    """
    Model class for text embedding models served from an Amazon SageMaker endpoint.
    """
    # boto3 "sagemaker-runtime" client; created lazily on first use and then reused.
    sagemaker_client: Any = None

    def _get_sagemaker_client(self, credentials: dict) -> Any:
        """
        Return the SageMaker runtime client, creating and caching it on first use

        :param credentials: model credentials; ``aws_access_key_id``, ``aws_secret_access_key``
            and ``aws_region`` are read and are all optional
        :return: boto3 ``sagemaker-runtime`` client
        """
        # NOTE(review): once a client exists it is reused as-is, so credentials passed on
        # later calls are ignored — confirm instances are never shared across credential sets.
        if not self.sagemaker_client:
            access_key = credentials.get('aws_access_key_id')
            secret_key = credentials.get('aws_secret_access_key')
            aws_region = credentials.get('aws_region')

            # Only pass what was actually configured; anything omitted is left for boto3
            # to resolve from the running environment.
            client_kwargs = {}
            if aws_region:
                client_kwargs['region_name'] = aws_region
            if access_key and secret_key:
                client_kwargs['aws_access_key_id'] = access_key
                client_kwargs['aws_secret_access_key'] = secret_key

            self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)

        return self.sagemaker_client

    def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]):
        """
        Embed one batch of texts with a single endpoint call

        :param sm_client: boto3 ``sagemaker-runtime`` client
        :param endpoint_name: SageMaker endpoint name
        :param content_list: texts to embed
        :return: value of the ``embeddings`` field of the JSON response
        """
        response_model = sm_client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps(
                {
                    "inputs": content_list,
                    "parameters": {},
                    "is_query": False,
                    "instruction": ''
                }
            ),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        return json_obj['embeddings']

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        try:
            client = self._get_sagemaker_client(credentials)
            sagemaker_endpoint = credentials.get('sagemaker_endpoint')

            # Truncation is by character count, not by tokens.
            truncated_texts = [item[:CONTEXT_SIZE] for item in texts]

            all_embeddings = []
            for batch in batch_generator(iter(truncated_texts), batch_size=BATCH_SIZE):
                all_embeddings.extend(self._sagemaker_embedding(client, sagemaker_endpoint, batch))

            # calc usage
            usage = self._calc_response_usage(
                model=model,
                credentials=credentials,
                tokens=0  # It's not SAAS API, usage is meaningless
            )

            return TextEmbeddingResult(
                embeddings=all_embeddings,
                usage=usage,
                model=model
            )
        except Exception:
            # Log for diagnosis, then propagate so callers see the failure
            # instead of a silent None.
            logger.exception('Failed to invoke SageMaker embedding endpoint')
            raise

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: always 0
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if a sample embedding request fails
        """
        try:
            # Exercise the endpoint with a minimal request so bad credentials or a
            # wrong endpoint name are detected here.
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        return EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name
        :param credentials: model credentials (unused)
        :return: model entity advertising the module's context size and batch size
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
                ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
            },
            parameter_rules=[]
        )
|
@ -0,0 +1,19 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.sagemaker.sagemaker import SageMakerProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_provider_credentials():
    """Provider-level validation accepts any credentials without raising."""
    provider = SageMakerProvider()

    # The same call cannot be expected to both raise and succeed, so only the
    # no-error expectation is kept.
    provider.validate_provider_credentials(
        credentials={}
    )
|
@ -0,0 +1,55 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.sagemaker.rerank.rerank import SageMakerRerankModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
    """Validation fails when no SageMaker endpoint is configured."""
    model = SageMakerRerankModel()

    # validate_credentials only accepts `model` and `credentials`; the sample
    # query and documents are supplied by the model class itself.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='bge-m3-rerank-v2',
            credentials={
                "aws_region": os.getenv("AWS_REGION"),
                "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"),
                "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
            }
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
    """Rerank two documents and expect only the second to pass the threshold."""
    model = SageMakerRerankModel()

    # NOTE(review): no `sagemaker_endpoint` is supplied here — confirm how the
    # endpoint name is meant to reach this integration test.
    result = model.invoke(
        model='bge-m3-rerank-v2',
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            # Key must be `aws_access_key_id` to match what the model class reads.
            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
        },
        query="What is the capital of the United States?",
        docs=[
            "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
            "Census, Carson City had a population of 55,274.",
            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
            "are a political division controlled by the United States. Its capital is Saipan.",
        ],
        score_threshold=0.8
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 1
    assert result.docs[0].score >= 0.8
|
@ -0,0 +1,55 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.sagemaker.text_embedding.text_embedding import SageMakerEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
    """Validation fails when the credentials are empty."""
    model = SageMakerEmbeddingModel()

    # Identical empty credentials cannot both fail and pass validation, so only
    # the failure expectation is kept.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='bge-m3',
            credentials={
            }
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
    """Embed two short texts and expect one embedding per text."""
    embedding_model = SageMakerEmbeddingModel()
    sample_texts = [
        "hello",
        "world"
    ]

    result = embedding_model.invoke(
        model='bge-m3-embedding',
        credentials={},
        texts=sample_texts,
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
|
||||||
|
|
||||||
|
def test_get_num_tokens():
    """Token counting reports zero for an empty list of texts."""
    embedding_model = SageMakerEmbeddingModel()

    token_count = embedding_model.get_num_tokens(
        model='bge-m3-embedding',
        credentials={},
        texts=[]
    )

    assert token_count == 0
|
Loading…
x
Reference in New Issue
Block a user