mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-03 11:25:13 +08:00
Enhancement: add model provider - Amazon Sagemaker (#6255)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
parent
dc847ba145
commit
4a026fa352
Binary file not shown.
After Width: | Height: | Size: 9.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 9.5 KiB |
238
api/core/model_runtime/model_providers/sagemaker/llm/llm.py
Normal file
238
api/core/model_runtime/model_providers/sagemaker/llm/llm.py
Normal file
@ -0,0 +1,238 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
Model class for Cohere large language model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model, credentials)
|
||||
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get('access_key')
|
||||
secret_key = credentials.get('secret_key')
|
||||
aws_region = credentials.get('aws_region')
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
|
||||
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=sagemaker_endpoint,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": prompt_messages[0].content,
|
||||
"parameters": { "stop" : stop},
|
||||
"history" : []
|
||||
}
|
||||
),
|
||||
ContentType="application/json",
|
||||
)
|
||||
|
||||
assistant_text = response_model['Body'].read().decode('utf8')
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
|
||||
usage = self._calc_response_usage(model, credentials, 0, 0)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
|
||||
try:
|
||||
return 0
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError,
|
||||
KeyError,
|
||||
ValueError
|
||||
]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度',
|
||||
en_US='Temperature'
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P',
|
||||
en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens',
|
||||
type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
max=credentials.get('context_length', 2048),
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度',
|
||||
en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type == LLMMode.CHAT:
|
||||
print(f"completion_type : {LLMMode.CHAT.value}")
|
||||
|
||||
if completion_type == LLMMode.COMPLETION:
|
||||
print(f"completion_type : {LLMMode.COMPLETION.value}")
|
||||
|
||||
features = []
|
||||
|
||||
support_function_call = credentials.get('support_function_call', False)
|
||||
if support_function_call:
|
||||
features.append(ModelFeature.TOOL_CALL)
|
||||
|
||||
support_vision = credentials.get('support_vision', False)
|
||||
if support_vision:
|
||||
features.append(ModelFeature.VISION)
|
||||
|
||||
context_length = credentials.get('context_length', 2048)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
features=features,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: completion_type,
|
||||
ModelPropertyKey.CONTEXT_SIZE: context_length
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
return entity
|
@ -0,0 +1,190 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SageMakerRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Cohere rerank model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
|
||||
inputs = [query_input]*len(docs)
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=rerank_endpoint,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": inputs,
|
||||
"docs": docs
|
||||
}
|
||||
),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_obj = json.loads(json_str)
|
||||
scores = json_obj['scores']
|
||||
return scores if isinstance(scores, list) else [scores]
|
||||
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if len(docs) == 0:
|
||||
return RerankResult(
|
||||
model=model,
|
||||
docs=docs
|
||||
)
|
||||
|
||||
line = 1
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get('aws_access_key_id')
|
||||
secret_key = credentials.get('aws_secret_access_key')
|
||||
aws_region = credentials.get('aws_region')
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
line = 2
|
||||
|
||||
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||
candidate_docs = []
|
||||
|
||||
scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
|
||||
for idx in range(len(scores)):
|
||||
candidate_docs.append({"content" : docs[idx], "score": scores[idx]})
|
||||
|
||||
sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
line = 3
|
||||
rerank_documents = []
|
||||
for idx, result in enumerate(candidate_docs):
|
||||
rerank_document = RerankDocument(
|
||||
index=idx,
|
||||
text=result.get('content'),
|
||||
score=result.get('score', -100.0)
|
||||
)
|
||||
|
||||
if score_threshold is not None:
|
||||
if rerank_document.score >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(
|
||||
model=model,
|
||||
docs=rerank_documents
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'Exception {e}, line : {line}')
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError,
|
||||
KeyError,
|
||||
ValueError
|
||||
]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.RERANK,
|
||||
model_properties={ },
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
@ -0,0 +1,17 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SageMakerProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
pass
|
125
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Normal file
125
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Normal file
@ -0,0 +1,125 @@
|
||||
provider: sagemaker
|
||||
label:
|
||||
zh_Hans: Sagemaker
|
||||
en_US: Sagemaker
|
||||
icon_small:
|
||||
en_US: icon_s_en.png
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
description:
|
||||
en_US: Customized model on Sagemaker
|
||||
zh_Hans: Sagemaker上的私有化部署的模型
|
||||
background: "#ECE9E3"
|
||||
help:
|
||||
title:
|
||||
en_US: How to deploy customized model on Sagemaker
|
||||
zh_Hans: 如何在Sagemaker上的私有化部署的模型
|
||||
url:
|
||||
en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
|
||||
zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: mode
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Completion mode
|
||||
type: select
|
||||
required: false
|
||||
default: chat
|
||||
placeholder:
|
||||
zh_Hans: 选择对话类型
|
||||
en_US: Select completion mode
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
zh_Hans: 补全
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
||||
zh_Hans: 对话
|
||||
- variable: sagemaker_endpoint
|
||||
label:
|
||||
en_US: sagemaker endpoint
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 请输出你的Sagemaker推理端点
|
||||
en_US: Enter your Sagemaker Inference endpoint
|
||||
- variable: aws_access_key_id
|
||||
required: false
|
||||
label:
|
||||
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
|
||||
zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
|
||||
type: secret-input
|
||||
placeholder:
|
||||
en_US: Enter your Access Key
|
||||
zh_Hans: 在此输入您的 Access Key
|
||||
- variable: aws_secret_access_key
|
||||
required: false
|
||||
label:
|
||||
en_US: Secret Access Key
|
||||
zh_Hans: Secret Access Key
|
||||
type: secret-input
|
||||
placeholder:
|
||||
en_US: Enter your Secret Access Key
|
||||
zh_Hans: 在此输入您的 Secret Access Key
|
||||
- variable: aws_region
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 地区
|
||||
type: select
|
||||
default: us-east-1
|
||||
options:
|
||||
- value: us-east-1
|
||||
label:
|
||||
en_US: US East (N. Virginia)
|
||||
zh_Hans: 美国东部 (弗吉尼亚北部)
|
||||
- value: us-west-2
|
||||
label:
|
||||
en_US: US West (Oregon)
|
||||
zh_Hans: 美国西部 (俄勒冈州)
|
||||
- value: ap-southeast-1
|
||||
label:
|
||||
en_US: Asia Pacific (Singapore)
|
||||
zh_Hans: 亚太地区 (新加坡)
|
||||
- value: ap-northeast-1
|
||||
label:
|
||||
en_US: Asia Pacific (Tokyo)
|
||||
zh_Hans: 亚太地区 (东京)
|
||||
- value: eu-central-1
|
||||
label:
|
||||
en_US: Europe (Frankfurt)
|
||||
zh_Hans: 欧洲 (法兰克福)
|
||||
- value: us-gov-west-1
|
||||
label:
|
||||
en_US: AWS GovCloud (US-West)
|
||||
zh_Hans: AWS GovCloud (US-West)
|
||||
- value: ap-southeast-2
|
||||
label:
|
||||
en_US: Asia Pacific (Sydney)
|
||||
zh_Hans: 亚太地区 (悉尼)
|
||||
- value: cn-north-1
|
||||
label:
|
||||
en_US: AWS Beijing (cn-north-1)
|
||||
zh_Hans: 中国北京 (cn-north-1)
|
||||
- value: cn-northwest-1
|
||||
label:
|
||||
en_US: AWS Ningxia (cn-northwest-1)
|
||||
zh_Hans: 中国宁夏 (cn-northwest-1)
|
@ -0,0 +1,214 @@
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
BATCH_SIZE = 20
|
||||
CONTEXT_SIZE=8192
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def batch_generator(generator, batch_size):
|
||||
while True:
|
||||
batch = list(itertools.islice(generator, batch_size))
|
||||
if not batch:
|
||||
break
|
||||
yield batch
|
||||
|
||||
class SageMakerEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Cohere text embedding model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
|
||||
def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]):
|
||||
response_model = sm_client.invoke_endpoint(
|
||||
EndpointName=endpoint_name,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": content_list,
|
||||
"parameters": {},
|
||||
"is_query" : False,
|
||||
"instruction" : ''
|
||||
}
|
||||
),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_obj = json.loads(json_str)
|
||||
embeddings = json_obj['embeddings']
|
||||
return embeddings
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
# get model properties
|
||||
try:
|
||||
line = 1
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get('aws_access_key_id')
|
||||
secret_key = credentials.get('aws_secret_access_key')
|
||||
aws_region = credentials.get('aws_region')
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
line = 2
|
||||
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||
|
||||
line = 3
|
||||
truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ]
|
||||
|
||||
batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
|
||||
all_embeddings = []
|
||||
|
||||
line = 4
|
||||
for batch in batches:
|
||||
embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
line = 5
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=0 # It's not SAAS API, usage is meaningless
|
||||
)
|
||||
line = 6
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=all_embeddings,
|
||||
usage=usage,
|
||||
model=model
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'Exception {e}, line : {line}')
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return 0
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
print("validate_credentials ok....")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
|
||||
ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
|
||||
},
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
@ -0,0 +1,19 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.sagemaker.sagemaker import SageMakerProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = SageMakerProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
@ -0,0 +1,55 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.sagemaker.rerank.rerank import SageMakerRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = SageMakerRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='bge-m3-rerank-v2',
|
||||
credentials={
|
||||
"aws_region": os.getenv("AWS_REGION"),
|
||||
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
|
||||
},
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = SageMakerRerankModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='bge-m3-rerank-v2',
|
||||
credentials={
|
||||
"aws_region": os.getenv("AWS_REGION"),
|
||||
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
|
||||
},
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 1
|
||||
assert result.docs[0].score >= 0.8
|
@ -0,0 +1,55 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.sagemaker.text_embedding.text_embedding import SageMakerEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = SageMakerEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='bge-m3',
|
||||
credentials={
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='bge-m3-embedding',
|
||||
credentials={
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = SageMakerEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='bge-m3-embedding',
|
||||
credentials={
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = SageMakerEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='bge-m3-embedding',
|
||||
credentials={
|
||||
},
|
||||
texts=[
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 0
|
Loading…
x
Reference in New Issue
Block a user