mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 03:45:55 +08:00
feat: Add support for embed file with AWS Bedrock Titan Model (#3377)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
6fa0e4072d
commit
f7a417fdb4
@ -15,6 +15,7 @@ help:
|
|||||||
en_US: https://console.aws.amazon.com/
|
en_US: https://console.aws.amazon.com/
|
||||||
supported_model_types:
|
supported_model_types:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
configurate_methods:
|
configurate_methods:
|
||||||
- predefined-model
|
- predefined-model
|
||||||
provider_credential_schema:
|
provider_credential_schema:
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
- amazon.titan-embed-text-v1
|
@ -0,0 +1,8 @@
|
|||||||
|
model: amazon.titan-embed-text-v1
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 8192
|
||||||
|
pricing:
|
||||||
|
input: '0.0001'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: USD
|
@ -0,0 +1,209 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.config import Config
|
||||||
|
from botocore.exceptions import (
|
||||||
|
ClientError,
|
||||||
|
EndpointConnectionError,
|
||||||
|
NoRegionError,
|
||||||
|
ServiceNotInRegionError,
|
||||||
|
UnknownServiceError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||||
|
|
||||||
|
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
texts: list[str], user: Optional[str] = None) \
|
||||||
|
-> TextEmbeddingResult:
|
||||||
|
"""
|
||||||
|
Invoke text embedding model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:param user: unique user id
|
||||||
|
:return: embeddings result
|
||||||
|
"""
|
||||||
|
client_config = Config(
|
||||||
|
region_name=credentials["aws_region"]
|
||||||
|
)
|
||||||
|
|
||||||
|
bedrock_runtime = boto3.client(
|
||||||
|
service_name='bedrock-runtime',
|
||||||
|
config=client_config,
|
||||||
|
aws_access_key_id=credentials["aws_access_key_id"],
|
||||||
|
aws_secret_access_key=credentials["aws_secret_access_key"]
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
token_usage = 0
|
||||||
|
|
||||||
|
model_prefix = model.split('.')[0]
|
||||||
|
if model_prefix == "amazon":
|
||||||
|
for text in texts:
|
||||||
|
body = {
|
||||||
|
"inputText": text,
|
||||||
|
}
|
||||||
|
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
||||||
|
embeddings.extend([response_body.get('embedding')])
|
||||||
|
token_usage += response_body.get('inputTextTokenCount')
|
||||||
|
result = TextEmbeddingResult(
|
||||||
|
model=model,
|
||||||
|
embeddings=embeddings,
|
||||||
|
usage=self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
tokens=token_usage
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
|
"""
|
||||||
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
num_tokens = 0
|
||||||
|
for text in texts:
|
||||||
|
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||||
|
return num_tokens
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||||
|
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
|
||||||
|
which needs to be converted into a unified error type for the caller.
|
||||||
|
|
||||||
|
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [],
|
||||||
|
InvokeServerUnavailableError: [],
|
||||||
|
InvokeRateLimitError: [],
|
||||||
|
InvokeAuthorizationError: [],
|
||||||
|
InvokeBadRequestError: []
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
|
||||||
|
"""
|
||||||
|
Create payload for bedrock api call depending on model provider
|
||||||
|
"""
|
||||||
|
payload = dict()
|
||||||
|
|
||||||
|
if model_prefix == "amazon":
|
||||||
|
payload['inputText'] = texts
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||||
|
"""
|
||||||
|
Calculate response usage
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param tokens: input tokens
|
||||||
|
:return: usage
|
||||||
|
"""
|
||||||
|
# get input price info
|
||||||
|
input_price_info = self.get_price(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
price_type=PriceType.INPUT,
|
||||||
|
tokens=tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
usage = EmbeddingUsage(
|
||||||
|
tokens=tokens,
|
||||||
|
total_tokens=tokens,
|
||||||
|
unit_price=input_price_info.unit_price,
|
||||||
|
price_unit=input_price_info.unit,
|
||||||
|
total_price=input_price_info.total_amount,
|
||||||
|
currency=input_price_info.currency,
|
||||||
|
latency=time.perf_counter() - self.started_at
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
|
||||||
|
"""
|
||||||
|
Map client error to invoke error
|
||||||
|
|
||||||
|
:param error_code: error code
|
||||||
|
:param error_msg: error message
|
||||||
|
:return: invoke error
|
||||||
|
"""
|
||||||
|
|
||||||
|
if error_code == "AccessDeniedException":
|
||||||
|
return InvokeAuthorizationError(error_msg)
|
||||||
|
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
|
||||||
|
return InvokeBadRequestError(error_msg)
|
||||||
|
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
|
||||||
|
return InvokeRateLimitError(error_msg)
|
||||||
|
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
|
||||||
|
return InvokeServerUnavailableError(error_msg)
|
||||||
|
elif error_code == "ModelStreamErrorException":
|
||||||
|
return InvokeConnectionError(error_msg)
|
||||||
|
|
||||||
|
return InvokeError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ):
|
||||||
|
accept = 'application/json'
|
||||||
|
content_type = 'application/json'
|
||||||
|
try:
|
||||||
|
response = bedrock_runtime.invoke_model(
|
||||||
|
body=json.dumps(body),
|
||||||
|
modelId=model,
|
||||||
|
accept=accept,
|
||||||
|
contentType=content_type
|
||||||
|
)
|
||||||
|
response_body = json.loads(response.get('body').read().decode('utf-8'))
|
||||||
|
return response_body
|
||||||
|
except ClientError as ex:
|
||||||
|
error_code = ex.response['Error']['Code']
|
||||||
|
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
||||||
|
raise self._map_client_to_invoke_error(error_code, full_error_msg)
|
||||||
|
|
||||||
|
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
|
||||||
|
raise InvokeConnectionError(str(ex))
|
||||||
|
|
||||||
|
except UnknownServiceError as ex:
|
||||||
|
raise InvokeServerUnavailableError(str(ex))
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
raise InvokeError(str(ex))
|
Loading…
x
Reference in New Issue
Block a user