feat: support more model types and builtin tools on aws/sagemaker (#8061)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
ybalbert001 2024-09-09 10:34:11 +08:00 committed by GitHub
parent ab7d79275e
commit 954580a4af
17 changed files with 1452 additions and 72 deletions

View File

@ -1,17 +1,36 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast
# from openai.types.chat import ChatCompletion, ChatCompletionChunk
import boto3
from sagemaker import Predictor, serializers
from sagemaker.session import Session
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
I18nObject,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], stop:list, stream=False):
"""
params:
predictor : Sagemaker Predictor
messages (List[Dict[str,Any]]): message list
messages = [
{"role": "system", "content":"please answer in Chinese"},
{"role": "user", "content": "who are you? what are you doing?"},
]
params (Dict[str,Any]): model parameters for LLM
stream (bool): False by default
response:
result of inference if stream is False
Iterator of Chunks if stream is True
"""
payload = {
"model" : params.get('model_name'),
"stop" : stop,
"messages": messages,
"stream" : stream,
"max_tokens" : params.get('max_new_tokens', params.get('max_tokens', 2048)),
"temperature" : params.get('temperature', 0.1),
"top_p" : params.get('top_p', 0.9),
}
if not stream:
response = predictor.predict(payload)
return response
else:
response_stream = predictor.predict_stream(payload)
return response_stream
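For context, a minimal usage sketch of this helper, assuming the same SDK pieces imported above and a deployed OpenAI-compatible LLM endpoint (the endpoint and model names below are placeholders):

import boto3
from sagemaker import Predictor, serializers
from sagemaker.session import Session

# build a Predictor against a hypothetical endpoint, mirroring what _invoke does below
sm_client = boto3.client("sagemaker-runtime", region_name="us-east-1")
predictor = Predictor(
    endpoint_name="my-llm-endpoint",  # placeholder
    sagemaker_session=Session(sagemaker_runtime_client=sm_client),
    serializer=serializers.JSONSerializer(),
)

messages = [
    {"role": "system", "content": "please answer in Chinese"},
    {"role": "user", "content": "who are you? what are you doing?"},
]
# non-streaming call; the exact response shape depends on the serving container
result = inference(predictor, messages, params={"model_name": "my-model"}, stop=["<|im_end|>"])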
class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
Model class for SageMaker large language model.
"""
sagemaker_client: Any = None
sagemaker_sess : Any = None
predictor : Any = None
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: bytes) -> LLMResult:
"""
handle normal chat generate response
"""
resp_obj = json.loads(resp.decode('utf-8'))
resp_str = resp_obj.get('choices')[0].get('message').get('content')
if len(resp_str) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_prompt_message = AssistantPromptMessage(
content=resp_str,
tool_calls=[]
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
usage=usage,
message=assistant_prompt_message,
)
return response
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Iterator[bytes]) -> Generator:
"""
handle stream chat generate response
"""
full_response = ''
buffer = ""
for chunk_bytes in resp:
buffer += chunk_bytes.decode('utf-8')
last_idx = 0
for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer, re.MULTILINE):
try:
data = json.loads(match.group(1).strip())
last_idx = match.span()[1]
if "content" in data["choices"][0]["delta"]:
chunk_content = data["choices"][0]["delta"]["content"]
assistant_prompt_message = AssistantPromptMessage(
content=chunk_content,
tool_calls=[]
)
if data["choices"][0]['finish_reason'] is not None:
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=[]
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=data["choices"][0]['finish_reason'],
usage=usage
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message
),
)
full_response += chunk_content
except (json.JSONDecodeError, KeyError, IndexError) as e:
logger.info("JSON parse exception {}, content: {}".format(e, match.group(1).strip()))
buffer = buffer[last_idx:]
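The parsing above is plain SSE buffering: bytes are accumulated, complete "data: ...\n\n" events are extracted with a regex, and the unconsumed tail is carried into the next chunk. A standalone sketch of the same idea (illustrative only, not the class's API):

import json
import re

def iter_sse_events(byte_iter):
    buffer = ""
    for chunk in byte_iter:
        buffer += chunk.decode("utf-8")
        last_idx = 0
        # re.MULTILINE lets '^' match every event in the buffer, not just the first
        for match in re.finditer(r"^data:\s*(.+?)\n\n", buffer, re.MULTILINE):
            last_idx = match.end()
            yield json.loads(match.group(1).strip())
        buffer = buffer[last_idx:]  # keep the incomplete tail for the next chunk

fake_stream = [b'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}\n\n']
for event in iter_sse_events(fake_stream):
    print(event["choices"][0]["delta"]["content"])  # -> Hi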
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# get model mode
model_mode = self.get_model_mode(model, credentials)
if not self.sagemaker_client:
access_key = credentials.get('access_key')
secret_key = credentials.get('secret_key')
@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
self.predictor = Predictor(
endpoint_name=credentials.get('sagemaker_endpoint'),
sagemaker_session=sagemaker_session,
serializer=serializers.JSONSerializer(),
)
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(
{
"inputs": prompt_messages[0].content,
"parameters": { "stop" : stop},
"history" : []
}
),
ContentType="application/json",
)
assistant_text = response_model['Body'].read().decode('utf8')
messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ]
response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
if stream:
if tools and len(tools) > 0:
raise InvokeBadRequestError(f"{model} does not support tool calls in stream mode")
usage = self._calc_response_usage(model, credentials, 0, 0)
return self._handle_chat_stream_response(model=model, credentials=credentials,
prompt_messages=prompt_messages,
tools=tools, resp=response)
return self._handle_chat_generate_response(model=model, credentials=credentials,
prompt_messages=prompt_messages,
tools=tools, resp=response)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
)
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return response
return message_dict
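For reference, a text-plus-image user message comes out in the OpenAI-compatible shape sketched below (the URL and detail value are illustrative):

# illustrative output of _convert_prompt_message_to_dict for a multimodal user message
converted = {
    "role": "user",
    "content": [
        {"type": "text", "text": "describe this image"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,...", "detail": "low"}},
    ],
}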
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
is_completion_model: bool = False) -> int:
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
if is_completion_model:
return sum(tokens(str(message.content)) for message in messages)
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
if key == "function_call":
for t_key, t_value in value.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
else:
num_tokens += tokens(str(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
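The accounting follows the common GPT-style heuristic: 3 overhead tokens per message, 1 extra token when a name field is present, and 3 tokens to prime the assistant reply. A rough worked illustration, using a whitespace tokenizer as a stand-in for the GPT-2 tokenizer used above:

# illustration only; len(text.split()) stands in for self._get_num_tokens_by_gpt2
def tokens(text: str) -> int:
    return len(text.split())

messages = [{"role": "user", "content": "who are you"}]
num_tokens = 0
for message in messages:
    num_tokens += 3                       # tokens_per_message
    for value in message.values():
        num_tokens += tokens(str(value))  # "user" -> 1, "who are you" -> 3
num_tokens += 3                           # priming for the assistant reply
print(num_tokens)                         # 3 + 1 + 3 + 3 = 10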
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:return:
"""
# get model mode
model_mode = self.get_model_mode(model)
try:
return 0
return self._num_tokens_from_messages(prompt_messages, tools)
except Exception as e:
raise self._transform_invoke_error(e)
@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
try:
# get model mode
model_mode = self.get_model_mode(model)
pass
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
)
]
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type == LLMMode.CHAT:
print(f"completion_type : {LLMMode.CHAT.value}")
if completion_type == LLMMode.COMPLETION:
print(f"completion_type : {LLMMode.COMPLETION.value}")
completion_type = LLMMode.value_of(credentials["mode"]).value
features = []

View File

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class SageMakerRerankModel(RerankModel):
"""
Model class for Cohere rerank model.
Model class for SageMaker rerank model.
"""
sagemaker_client: Any = None

View File

@ -1,10 +1,11 @@
import logging
import uuid
from typing import IO, Any
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class SageMakerProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@ -15,3 +16,28 @@ class SageMakerProvider(ModelProvider):
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass
def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str:
'''
upload the file to S3 and return its object key
'''
s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3'
s3_client.put_object(
Body=file.read(),
Bucket=bucket,
Key=s3_key,
ContentType='audio/mp3'
)
return s3_key
def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str:
object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix)
try:
response = s3_client.generate_presigned_url('get_object',
Params={'Bucket': bucket_name, 'Key': object_key},
ExpiresIn=expiration)
except Exception as e:
logger.exception(f"Error generating presigned URL: {e}")
return None
return response
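A minimal usage sketch for these helpers (bucket name and file are placeholders; the caller needs s3:PutObject and s3:GetObject permissions):

import boto3

s3_client = boto3.client("s3", region_name="us-east-1")
with open("sample.mp3", "rb") as f:
    url = generate_presigned_url(s3_client, f, "my-audio-cache-bucket", "dify/speech2text/", expiration=600)
print(url)  # time-limited GET URL, or None if presigning failed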

View File

@ -21,6 +21,8 @@ supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
@ -45,14 +47,10 @@ model_credential_schema:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
zh_Hans: Chat
- variable: sagemaker_endpoint
label:
en_US: sagemaker endpoint
@ -61,6 +59,76 @@ model_credential_schema:
placeholder:
zh_Hans: 请输入你的Sagemaker推理端点
en_US: Enter your Sagemaker Inference endpoint
- variable: audio_s3_cache_bucket
show_on:
- variable: __model_type
value: speech2text
label:
zh_Hans: 音频缓存桶(s3 bucket)
en_US: audio cache bucket(s3 bucket)
type: text-input
required: true
placeholder:
zh_Hans: sagemaker-us-east-1-******207838
en_US: sagemaker-us-east-1-*******7838
- variable: audio_model_type
show_on:
- variable: __model_type
value: tts
label:
en_US: Audio model type
type: select
required: true
placeholder:
zh_Hans: 语音模型类型
en_US: Audio model type
options:
- value: PresetVoice
label:
en_US: preset voice
zh_Hans: 内置音色
- value: CloneVoice
label:
en_US: clone voice
zh_Hans: 克隆音色
- value: CloneVoice_CrossLingual
label:
en_US: crosslingual clone voice
zh_Hans: 跨语种克隆音色
- value: InstructVoice
label:
en_US: Instruct voice
zh_Hans: 文字指令音色
- variable: prompt_audio
show_on:
- variable: __model_type
value: tts
label:
en_US: Mock Audio Source
type: text-input
required: false
placeholder:
zh_Hans: 被模仿的音色音频
en_US: source audio to be mocked
- variable: prompt_text
show_on:
- variable: __model_type
value: tts
label:
en_US: Prompt Audio Text
type: text-input
required: false
placeholder:
zh_Hans: 模仿音色的对应文本
en_US: text for the mocked source audio
- variable: instruct_text
show_on:
- variable: __model_type
value: tts
label:
en_US: instruct text for speaker
type: text-input
required: false
- variable: aws_access_key_id
required: false
label:

View File

@ -0,0 +1,142 @@
import json
import logging
from typing import IO, Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
logger = logging.getLogger(__name__)
class SageMakerSpeech2TextModel(Speech2TextModel):
"""
Model class for SageMaker speech-to-text model.
"""
sagemaker_client: Any = None
s3_client : Any = None
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
asr_text = None
try:
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.s3_client = boto3.client("s3",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
s3_prefix='dify/speech2text/'
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
bucket = credentials.get('audio_s3_cache_bucket')
s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
payload = {
"audio_s3_presign_uri" : s3_presign_url
}
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(payload),
ContentType="application/json"
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
asr_text = json_obj['text']
except Exception as e:
logger.exception(f'Exception: {e}')
return asr_text
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
pass
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={ },
parameter_rules=[]
)
return entity
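Taken together, the endpoint contract this model assumes is a JSON request carrying audio_s3_presign_uri and a JSON response carrying text. A minimal sketch of exercising that contract directly (endpoint name and URL are placeholders):

import json

import boto3

sm_client = boto3.client("sagemaker-runtime", region_name="us-east-1")
response = sm_client.invoke_endpoint(
    EndpointName="my-asr-endpoint",  # placeholder
    Body=json.dumps({"audio_s3_presign_uri": "https://example.com/audio.mp3?X-Amz-..."}),
    ContentType="application/json",
)
print(json.loads(response["Body"].read().decode("utf8"))["text"])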

View File

@ -0,0 +1,287 @@
import concurrent.futures
import copy
import json
import logging
from enum import Enum
from typing import Any, Optional
import boto3
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tts_model import TTSModel
logger = logging.getLogger(__name__)
class TTSModelType(Enum):
PresetVoice = "PresetVoice"
CloneVoice = "CloneVoice"
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
InstructVoice = "InstructVoice"
class SageMakerText2SpeechModel(TTSModel):
sagemaker_client: Any = None
s3_client : Any = None
comprehend_client : Any = None
def __init__(self):
# preset voices (custom voice support still to be added)
self.model_voices = {
'__default': {
'all': [
{'name': 'Default', 'value': 'default'},
]
},
'CosyVoice': {
'zh-Hans': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'zh-Hant': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'en-US': [
{'name': '英文男', 'value': '英文男'},
{'name': '英文女', 'value': '英文女'},
],
'ja-JP': [
{'name': '日语男', 'value': '日语男'},
],
'ko-KR': [
{'name': '韩语女', 'value': '韩语女'},
]
}
}
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
pass
def _detect_lang_code(self, content:str):
map_dict = {
"zh" : "<|zh|>",
"en" : "<|en|>",
"ja" : "<|jp|>",
"zh-TW" : "<|yue|>",
"ko" : "<|ko|>"
}
response = self.comprehend_client.detect_dominant_language(Text=content)
language_code = response['Languages'][0]['LanguageCode']
return map_dict.get(language_code, '<|zh|>')
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
if model_type == TTSModelType.PresetVoice.value and model_role:
return { "tts_text" : content_text, "role" : model_role }
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
lang_tag = self._detect_lang_code(content_text)
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
raise RuntimeError(f"Invalid params for {model_type}")
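Each inference type therefore requires a different parameter combination; illustrative payloads (all values are placeholders):

preset_voice = {"tts_text": "你好", "role": "中文女"}
clone_voice = {"tts_text": "你好", "prompt_text": "参考音频的文本", "prompt_audio": "https://example.com/ref.wav"}
cross_lingual = {"tts_text": "hello", "prompt_audio": "https://example.com/ref.wav", "lang_tag": "<|en|>"}
instruct_voice = {"tts_text": "你好", "role": "中文男", "instruct_text": "用温柔的语气"}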
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: generator yielding the synthesized audio in chunks
"""
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.s3_client = boto3.client("s3",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.comprehend_client = boto3.client('comprehend',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
self.comprehend_client = boto3.client('comprehend')
model_type = credentials.get('audio_model_type', 'PresetVoice')
prompt_text = credentials.get('prompt_text')
prompt_audio = credentials.get('prompt_audio')
instruct_text = credentials.get('instruct_text')
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
payload = self._build_tts_payload(
model_type,
content_text,
voice,
prompt_text,
prompt_audio,
instruct_text
)
return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={},
parameter_rules=[]
)
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
return ""
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
return 15
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
return "mp3"
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
return 5
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
audio_model_name = 'CosyVoice'
for key, voices in self.model_voices.items():
if key in audio_model_name:
if language and language in voices:
return voices[language]
elif 'all' in voices:
return voices['all']
return self.model_voices['__default']['all']
def _invoke_sagemaker(self, payload:dict, endpoint:str):
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=endpoint,
Body=json.dumps(payload),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
return json_obj
def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> Any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
try:
lang_tag = ''
if model_type == TTSModelType.CloneVoice_CrossLingual.value:
lang_tag = payload.pop('lang_tag')
word_limit = self._get_model_word_limit(model='', credentials={})
content_text = payload.get("tts_text")
if len(content_text) > word_limit:
split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ]
len_sent = len(sentences)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent))
payloads = [ copy.deepcopy(payload) for i in range(len_sent) ]
for idx in range(len_sent):
payloads[idx]["tts_text"] = sentences[idx]
futures = [ executor.submit(
self._invoke_sagemaker,
payload=payload,
endpoint=sagemaker_endpoint,
)
for payload in payloads]
for index, future in enumerate(futures):
resp = future.result()
audio_bytes = requests.get(resp.get('s3_presign_url')).content
for i in range(0, len(audio_bytes), 1024):
yield audio_bytes[i:i + 1024]
else:
resp = self._invoke_sagemaker(payload, sagemaker_endpoint)
audio_bytes = requests.get(resp.get('s3_presign_url')).content
for i in range(0, len(audio_bytes), 1024):
yield audio_bytes[i:i + 1024]
except Exception as ex:
raise InvokeBadRequestError(str(ex))
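A sketch of consuming the resulting chunk generator through _invoke, which initializes the boto3 clients before the generator runs (credentials and endpoint name are placeholders):

model = SageMakerText2SpeechModel()
audio_chunks = model._invoke(
    model="cosyvoice",
    tenant_id="tenant-1",
    credentials={
        "sagemaker_endpoint": "my-tts-endpoint",  # placeholder
        "aws_region": "us-east-1",
        "audio_model_type": "PresetVoice",
    },
    content_text="你好,世界",
    voice="中文女",
)
with open("out.mp3", "wb") as f:
    for chunk in audio_chunks:
        f.write(chunk)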

View File

@ -3,6 +3,7 @@ import logging
from typing import Any, Union
import boto3
from botocore.exceptions import BotoCoreError
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -16,7 +17,7 @@ class GuardrailParameters(BaseModel):
guardrail_version: str = Field(..., description="The version of the guardrail")
source: str = Field(..., description="The source of the content")
text: str = Field(..., description="The text to apply the guardrail to")
aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client")
aws_region: str = Field(..., description="AWS region for the Bedrock client")
class ApplyGuardrailTool(BuiltinTool):
def _invoke(self,
@ -40,6 +41,8 @@ class ApplyGuardrailTool(BuiltinTool):
source=params.source,
content=[{"text": {"text": params.text}}]
)
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
# Check for empty response
if not response:
@ -69,7 +72,7 @@ class ApplyGuardrailTool(BuiltinTool):
return self.create_text_message(text=result)
except boto3.exceptions.BotoCoreError as e:
except BotoCoreError as e:
error_message = f'AWS service error: {str(e)}'
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
@ -80,4 +83,4 @@ class ApplyGuardrailTool(BuiltinTool):
except Exception as e:
error_message = f'An unexpected error occurred: {str(e)}'
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
return self.create_text_message(text=error_message)

View File

@ -54,3 +54,14 @@ parameters:
zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
form: llm
- name: aws_region
type: string
required: true
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
form: form

View File

@ -0,0 +1,71 @@
import json
import logging
from typing import Any, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
class LambdaYamlToJsonTool(BuiltinTool):
lambda_client: Any = None
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
msg = {
"body": yaml_content
}
logger.info(json.dumps(msg))
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
InvocationType='RequestResponse',
Payload=json.dumps(msg))
response_body = invoke_response['Payload']
response_str = response_body.read().decode("utf-8")
resp_json = json.loads(response_str)
logger.info(resp_json)
if resp_json['statusCode'] != 200:
raise Exception(f"Invalid status code: {response_str}")
return resp_json['body']
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
try:
if not self.lambda_client:
aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region
if aws_region:
self.lambda_client = boto3.client("lambda", region_name=aws_region)
else:
self.lambda_client = boto3.client("lambda")
yaml_content = tool_parameters.get('yaml_content', '')
if not yaml_content:
return self.create_text_message('Please input yaml_content')
lambda_name = tool_parameters.get('lambda_name', '')
if not lambda_name:
return self.create_text_message('Please input lambda_name')
logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')
result = self._invoke_lambda(lambda_name, yaml_content)
logger.debug(result)
return self.create_text_message(result)
except Exception as e:
return self.create_text_message(f'Exception: {str(e)}')
console_handler.flush()
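_invoke_lambda above expects the function to accept {"body": <yaml>} and answer {"statusCode": 200, "body": <json string>}. A minimal hypothetical handler satisfying that contract (PyYAML must be bundled with the Lambda):

import json

import yaml  # assumption: packaged into the Lambda deployment

def lambda_handler(event, context):
    try:
        data = yaml.safe_load(event["body"])
        return {"statusCode": 200, "body": json.dumps(data, ensure_ascii=False)}
    except Exception as e:
        return {"statusCode": 400, "body": str(e)}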

View File

@ -0,0 +1,53 @@
identity:
name: lambda_yaml_to_json
author: AWS
label:
en_US: LambdaYamlToJson
zh_Hans: LambdaYamlToJson
pt_BR: LambdaYamlToJson
icon: icon.svg
description:
human:
en_US: A tool to convert yaml to json using AWS Lambda.
zh_Hans: 通过 AWS Lambda 将 YAML 转为 JSON 的工具
pt_BR: A tool to convert yaml to json using AWS Lambda.
llm: A tool to convert yaml to json.
parameters:
- name: yaml_content
type: string
required: true
label:
en_US: YAML content to convert
zh_Hans: YAML 内容
pt_BR: YAML content to convert
human_description:
en_US: YAML content to convert
zh_Hans: YAML 内容
pt_BR: YAML content to convert
llm_description: YAML content to convert
form: llm
- name: aws_region
type: string
required: false
label:
en_US: region of lambda
zh_Hans: Lambda 所在的region
pt_BR: region of lambda
human_description:
en_US: region of lambda
zh_Hans: Lambda 所在的region
pt_BR: region of lambda
llm_description: region of lambda
form: form
- name: lambda_name
type: string
required: false
label:
en_US: name of lambda
zh_Hans: Lambda 名称
pt_BR: name of lambda
human_description:
en_US: name of lambda
zh_Hans: Lambda 名称
pt_BR: name of lambda
form: form

View File

@ -78,9 +78,7 @@ class SageMakerReRankTool(BuiltinTool):
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
line = 9
results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False)
return self.create_text_message(text=results_str)
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
except Exception as e:
return self.create_text_message(f'Exception {str(e)}, line : {line}')
return self.create_text_message(f'Exception {str(e)}, line : {line}')

View File

@ -0,0 +1,95 @@
import json
from enum import Enum
from typing import Any, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class TTSModelType(Enum):
PresetVoice = "PresetVoice"
CloneVoice = "CloneVoice"
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
InstructVoice = "InstructVoice"
class SageMakerTTSTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint:str = None
s3_client : Any = None
comprehend_client : Any = None
def _detect_lang_code(self, content:str):
map_dict = {
"zh" : "<|zh|>",
"en" : "<|en|>",
"ja" : "<|jp|>",
"zh-TW" : "<|yue|>",
"ko" : "<|ko|>"
}
response = self.comprehend_client.detect_dominant_language(Text=content)
language_code = response['Languages'][0]['LanguageCode']
return map_dict.get(language_code, '<|zh|>')
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
if model_type == TTSModelType.PresetVoice.value and model_role:
return { "tts_text" : content_text, "role" : model_role }
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
lang_tag = self._detect_lang_code(content_text)
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
raise RuntimeError(f"Invalid params for {model_type}")
def _invoke_sagemaker(self, payload:dict, endpoint:str):
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=endpoint,
Body=json.dumps(payload),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
return json_obj
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
try:
if not self.sagemaker_client:
aws_region = tool_parameters.get('aws_region')
if aws_region:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
self.comprehend_client = boto3.client('comprehend')
if not self.sagemaker_endpoint:
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
tts_text = tool_parameters.get('tts_text')
tts_infer_type = tool_parameters.get('tts_infer_type')
voice = tool_parameters.get('voice')
mock_voice_audio = tool_parameters.get('mock_voice_audio')
mock_voice_text = tool_parameters.get('mock_voice_text')
voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
return self.create_text_message(text=result['s3_presign_url'])
except Exception as e:
return self.create_text_message(f'Exception {str(e)}')

View File

@ -0,0 +1,149 @@
identity:
name: sagemaker_tts
author: AWS
label:
en_US: SagemakerTTS
zh_Hans: Sagemaker语音合成
pt_BR: SagemakerTTS
icon: icon.svg
description:
human:
en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
pt_BR: A tool for Speech synthesis.
llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
parameters:
- name: sagemaker_endpoint
type: string
required: true
label:
en_US: sagemaker endpoint for tts
zh_Hans: 语音生成的SageMaker端点
pt_BR: sagemaker endpoint for tts
human_description:
en_US: sagemaker endpoint for tts
zh_Hans: 语音生成的SageMaker端点
pt_BR: sagemaker endpoint for tts
llm_description: sagemaker endpoint for tts
form: form
- name: tts_text
type: string
required: true
label:
en_US: tts text
zh_Hans: 语音合成原文
pt_BR: tts text
human_description:
en_US: tts text
zh_Hans: 语音合成原文
pt_BR: tts text
llm_description: tts text
form: llm
- name: tts_infer_type
type: select
required: false
label:
en_US: tts infer type
zh_Hans: 合成方式
pt_BR: tts infer type
human_description:
en_US: tts infer type
zh_Hans: 合成方式
pt_BR: tts infer type
llm_description: tts infer type
options:
- value: PresetVoice
label:
en_US: preset voice
zh_Hans: 预置音色
- value: CloneVoice
label:
en_US: clone voice
zh_Hans: 克隆音色
- value: CloneVoice_CrossLingual
label:
en_US: clone crossLingual voice
zh_Hans: 克隆音色(跨语言)
- value: InstructVoice
label:
en_US: instruct voice
zh_Hans: 指令音色
form: form
- name: voice
type: select
required: false
label:
en_US: preset voice
zh_Hans: 预置音色
pt_BR: preset voice
human_description:
en_US: preset voice
zh_Hans: 预置音色
pt_BR: preset voice
llm_description: preset voice
options:
- value: 中文男
label:
en_US: zh-cn male
zh_Hans: 中文男
- value: 中文女
label:
en_US: zh-cn female
zh_Hans: 中文女
- value: 粤语女
label:
en_US: zh-TW female
zh_Hans: 粤语女
form: form
- name: mock_voice_audio
type: string
required: false
label:
en_US: clone voice link
zh_Hans: 克隆音频链接
pt_BR: clone voice link
human_description:
en_US: clone voice link
zh_Hans: 克隆音频链接
pt_BR: clone voice link
llm_description: clone voice link
form: llm
- name: mock_voice_text
type: string
required: false
label:
en_US: text of clone voice
zh_Hans: 克隆音频对应文本
pt_BR: text of clone voice
human_description:
en_US: text of clone voice
zh_Hans: 克隆音频对应文本
pt_BR: text of clone voice
llm_description: text of clone voice
form: llm
- name: voice_instruct_prompt
type: string
required: false
label:
en_US: instruct prompt for voice
zh_Hans: 音色指令文本
pt_BR: instruct prompt for voice
human_description:
en_US: instruct prompt for voice
zh_Hans: 音色指令文本
pt_BR: instruct prompt for voice
llm_description: instruct prompt for voice
form: llm
- name: aws_region
type: string
required: false
label:
en_US: region of sagemaker endpoint
zh_Hans: SageMaker 端点所在的region
pt_BR: region of sagemaker endpoint
human_description:
en_US: region of sagemaker endpoint
zh_Hans: SageMaker 端点所在的region
pt_BR: region of sagemaker endpoint
llm_description: region of sagemaker endpoint
form: form

api/poetry.lock (generated)
View File

@ -520,22 +520,22 @@ files = [
[[package]]
name = "attrs"
version = "24.2.0"
version = "23.2.0"
description = "Classes Without Boilerplate"
optional = false
python-versions = ">=3.7"
files = [
{file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"},
{file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"},
{file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
{file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
]
[package.extras]
benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
cov = ["attrs[tests]", "coverage[toml] (>=5.3)"]
dev = ["attrs[tests]", "pre-commit"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"]
tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
[[package]]
name = "authlib"
@ -1719,6 +1719,17 @@ lz4 = ["clickhouse-cityhash (>=1.0.2.1)", "lz4", "lz4 (<=3.0.1)"]
numpy = ["numpy (>=1.12.0)", "pandas (>=0.24.0)"]
zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"]
[[package]]
name = "cloudpickle"
version = "2.2.1"
description = "Extended pickling support for Python objects"
optional = false
python-versions = ">=3.6"
files = [
{file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"},
{file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"},
]
[[package]]
name = "cloudscraper"
version = "1.2.71"
@ -2151,6 +2162,21 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "dill"
version = "0.3.8"
description = "serialize all of Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
{file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
]
[package.extras]
graph = ["objgraph (>=1.7.2)"]
profile = ["gprof2dot (>=2022.7.29)"]
[[package]]
name = "distro"
version = "1.9.0"
@ -2162,6 +2188,28 @@ files = [
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]
[[package]]
name = "docker"
version = "7.1.0"
description = "A Python library for the Docker Engine API."
optional = false
python-versions = ">=3.8"
files = [
{file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"},
{file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"},
]
[package.dependencies]
pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""}
requests = ">=2.26.0"
urllib3 = ">=1.26.0"
[package.extras]
dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"]
docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"]
ssh = ["paramiko (>=2.4.3)"]
websockets = ["websocket-client (>=1.3.0)"]
[[package]]
name = "docstring-parser"
version = "0.16"
@ -3309,6 +3357,21 @@ typing-extensions = "*"
[package.extras]
dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
[[package]]
name = "google-pasta"
version = "0.2.0"
description = "pasta is an AST-based Python refactoring library"
optional = false
python-versions = "*"
files = [
{file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"},
{file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"},
{file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"},
]
[package.dependencies]
six = "*"
[[package]]
name = "google-resumable-media"
version = "2.7.2"
@ -3930,22 +3993,22 @@ files = [
[[package]]
name = "importlib-metadata"
version = "8.4.0"
version = "6.11.0"
description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"},
{file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"},
{file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"},
{file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"},
]
[package.dependencies]
zipp = ">=0.5"
[package.extras]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"]
test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
[[package]]
name = "importlib-resources"
@ -4929,6 +4992,22 @@ files = [
[package.extras]
test = ["mypy (>=1.0)", "pytest (>=7.0.0)"]
[[package]]
name = "mock"
version = "4.0.3"
description = "Rolling backport of unittest.mock for all Pythons"
optional = false
python-versions = ">=3.6"
files = [
{file = "mock-4.0.3-py3-none-any.whl", hash = "sha256:122fcb64ee37cfad5b3f48d7a7d51875d7031aaf3d8be7c42e2bee25044eee62"},
{file = "mock-4.0.3.tar.gz", hash = "sha256:7d3fbbde18228f4ff2f1f119a45cdffa458b4c0dee32eb4d2bb2f82554bac7bc"},
]
[package.extras]
build = ["blurb", "twine", "wheel"]
docs = ["sphinx"]
test = ["pytest (<5.4)", "pytest-cov"]
[[package]]
name = "monotonic"
version = "1.6"
@ -5128,6 +5207,30 @@ files = [
{file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
]
[[package]]
name = "multiprocess"
version = "0.70.16"
description = "better multiprocessing and multithreading in Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"},
{file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"},
{file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"},
{file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"},
{file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"},
{file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"},
{file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"},
{file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"},
{file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"},
{file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"},
{file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"},
{file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"},
]
[package.dependencies]
dill = ">=0.3.8"
[[package]]
name = "multitasking"
version = "0.0.11"
@ -5955,6 +6058,23 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"]
[[package]]
name = "pathos"
version = "0.3.2"
description = "parallel graph management and execution in heterogeneous computing"
optional = false
python-versions = ">=3.8"
files = [
{file = "pathos-0.3.2-py3-none-any.whl", hash = "sha256:d669275e6eb4b3fbcd2846d7a6d1bba315fe23add0c614445ba1408d8b38bafe"},
{file = "pathos-0.3.2.tar.gz", hash = "sha256:4f2a42bc1e10ccf0fe71961e7145fc1437018b6b21bd93b2446abc3983e49a7a"},
]
[package.dependencies]
dill = ">=0.3.8"
multiprocess = ">=0.70.16"
pox = ">=0.3.4"
ppft = ">=1.7.6.8"
[[package]]
name = "peewee"
version = "3.17.6"
@ -6196,6 +6316,31 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"]
sentry = ["django", "sentry-sdk"]
test = ["coverage", "django", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"]
[[package]]
name = "pox"
version = "0.3.4"
description = "utilities for filesystem exploration and automated builds"
optional = false
python-versions = ">=3.8"
files = [
{file = "pox-0.3.4-py3-none-any.whl", hash = "sha256:651b8ae8a7b341b7bfd267f67f63106daeb9805f1ac11f323d5280d2da93fdb6"},
{file = "pox-0.3.4.tar.gz", hash = "sha256:16e6eca84f1bec3828210b06b052adf04cf2ab20c22fd6fbef5f78320c9a6fed"},
]
[[package]]
name = "ppft"
version = "1.7.6.8"
description = "distributed and parallel Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "ppft-1.7.6.8-py3-none-any.whl", hash = "sha256:de2dd4b1b080923dd9627fbdea52649fd741c752fce4f3cf37e26f785df23d9b"},
{file = "ppft-1.7.6.8.tar.gz", hash = "sha256:76a429a7d7b74c4d743f6dba8351e58d62b6432ed65df9fe204790160dab996d"},
]
[package.extras]
dill = ["dill (>=0.3.8)"]
[[package]]
name = "primp"
version = "0.6.1"
@ -8004,6 +8149,84 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
torch = ["safetensors[numpy]", "torch (>=1.10)"]
[[package]]
name = "sagemaker"
version = "2.231.0"
description = "Open source library for training and deploying models on Amazon SageMaker."
optional = false
python-versions = ">=3.8"
files = [
{file = "sagemaker-2.231.0-py3-none-any.whl", hash = "sha256:5b6d84484a58c6ac8b22af42c6c5e0ea3c5f42d719345fe6aafba42f93635000"},
{file = "sagemaker-2.231.0.tar.gz", hash = "sha256:d49ee9c35725832dd9810708938af723201b831e82924a3a6ac1c4260a3d8239"},
]
[package.dependencies]
attrs = ">=23.1.0,<24"
boto3 = ">=1.34.142,<2.0"
cloudpickle = "2.2.1"
docker = "*"
google-pasta = "*"
importlib-metadata = ">=1.4.0,<7.0"
jsonschema = "*"
numpy = ">=1.9.0,<2.0"
packaging = ">=20.0"
pandas = "*"
pathos = "*"
platformdirs = "*"
protobuf = ">=3.12,<5.0"
psutil = "*"
pyyaml = ">=6.0,<7.0"
requests = "*"
sagemaker-core = ">=1.0.0,<2.0.0"
schema = "*"
smdebug-rulesconfig = "1.0.1"
tblib = ">=1.7.0,<4"
tqdm = "*"
urllib3 = ">=1.26.8,<3.0.0"
[package.extras]
all = ["accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<8.0.0)", "fastapi (>=0.111.0)", "nest-asyncio", "pyspark (==3.3.1)", "pyyaml (>=5.4.1,<7)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)"]
feature-processor = ["pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3-3"]
huggingface = ["accelerate (>=0.24.1,<=0.27.0)", "fastapi (>=0.111.0)", "nest-asyncio", "sagemaker-schema-inference-artifacts (>=0.0.5)", "uvicorn (>=0.30.1)"]
local = ["docker (>=5.0.2,<8.0.0)", "pyyaml (>=5.4.1,<7)", "urllib3 (>=1.26.8,<3.0.0)"]
scipy = ["scipy (==1.10.1)"]
test = ["accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.3)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "build[virtualenv] (==1.2.1)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", "docker (>=5.0.2,<8.0.0)", "fabric (==2.6.0)", "fastapi (>=0.111.0)", "flake8 (==4.0.1)", "huggingface-hub (>=0.23.4)", "jinja2 (==3.1.4)", "mlflow (>=2.12.2,<2.13)", "mock (==4.0.3)", "nbformat (>=5.9,<6)", "nest-asyncio", "numpy (>=1.24.0)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark (==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "pyyaml (==6.0)", "pyyaml (>=5.4.1,<7)", "requests (==2.32.2)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)", "xgboost (>=1.6.2,<=1.7.6)"]
[[package]]
name = "sagemaker-core"
version = "1.0.2"
description = "An python package for sagemaker core functionalities"
optional = false
python-versions = ">=3.8"
files = [
{file = "sagemaker_core-1.0.2-py3-none-any.whl", hash = "sha256:ce8d38a4a32efa83e4bc037a8befc7e29f87cd3eaf99acc4472b607f75a0f45a"},
{file = "sagemaker_core-1.0.2.tar.gz", hash = "sha256:8fb942aac5e7ed928dab512ffe6facf8c6bdd4595df63c59c0bd0795ea434f8d"},
]
[package.dependencies]
boto3 = ">=1.34.0,<2.0.0"
importlib-metadata = ">=1.4.0,<7.0"
jsonschema = "<5.0.0"
mock = ">4.0,<5.0"
platformdirs = ">=4.0.0,<5.0.0"
pydantic = ">=1.7.0,<3.0.0"
PyYAML = ">=6.0,<7.0"
rich = ">=13.0.0,<14.0.0"
[package.extras]
codegen = ["black (>=24.3.0,<25.0.0)", "pandas (>=2.0.0,<3.0.0)", "pylint (>=3.0.0,<4.0.0)", "pytest (>=8.0.0,<9.0.0)"]
[[package]]
name = "schema"
version = "0.7.7"
description = "Simple data validation library"
optional = false
python-versions = "*"
files = [
{file = "schema-0.7.7-py2.py3-none-any.whl", hash = "sha256:5d976a5b50f36e74e2157b47097b60002bd4d42e65425fcc9c9befadb4255dde"},
{file = "schema-0.7.7.tar.gz", hash = "sha256:7da553abd2958a19dc2547c388cde53398b39196175a9be59ea1caf5ab0a1807"},
]
[[package]]
name = "scikit-learn"
version = "1.5.1"
@ -8276,6 +8499,17 @@ files = [
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
[[package]]
name = "smdebug-rulesconfig"
version = "1.0.1"
description = "SMDebug RulesConfig"
optional = false
python-versions = ">=2.7"
files = [
{file = "smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl", hash = "sha256:104da3e6931ecf879dfc687ca4bbb3bee5ea2bc27f4478e9dbb3ee3655f1ae61"},
{file = "smdebug_rulesconfig-1.0.1.tar.gz", hash = "sha256:7a19e6eb2e6bcfefbc07e4a86ef7a88f32495001a038bf28c7d8e77ab793fcd6"},
]
[[package]]
name = "sniffio"
version = "1.3.1"
@ -8473,6 +8707,17 @@ files = [
[package.extras]
widechars = ["wcwidth"]
[[package]]
name = "tblib"
version = "3.0.0"
description = "Traceback serialization library."
optional = false
python-versions = ">=3.8"
files = [
{file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"},
{file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"},
]
[[package]]
name = "tcvectordb"
version = "1.3.2"
@ -10126,4 +10371,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "78c7db0bf525a72f4c8309e3363304d1a0a23cf0a6836bfb974a38a1fcde9158"
content-hash = "c3c637d643f4dcb3e35d0e7f2a3a4fbaf2a730512a4ca31adce5884c94f07f57"

View File

@ -113,6 +113,7 @@ azure-identity = "1.16.1"
azure-storage-blob = "12.13.0"
beautifulsoup4 = "4.12.2"
boto3 = "1.34.148"
sagemaker = "2.231.0"
bs4 = "~0.0.1"
cachetools = "~5.3.0"
celery = "~5.3.6"