From 954580a4af67edd31e6832ed0e5aed4487127ba1 Mon Sep 17 00:00:00 2001 From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Date: Mon, 9 Sep 2024 10:34:11 +0800 Subject: [PATCH] feat: support more model types and builtin tools on aws/sagemaker (#8061) Co-authored-by: Yuanbo Li --- .../model_providers/sagemaker/llm/llm.py | 317 +++++++++++++++--- .../sagemaker/rerank/rerank.py | 2 +- .../model_providers/sagemaker/sagemaker.py | 28 +- .../model_providers/sagemaker/sagemaker.yaml | 78 ++++- .../sagemaker/speech2text/__init__.py | 0 .../sagemaker/speech2text/speech2text.py | 142 ++++++++ .../model_providers/sagemaker/tts/__init__.py | 0 .../model_providers/sagemaker/tts/tts.py | 287 ++++++++++++++++ .../builtin/aws/tools/apply_guardrail.py | 9 +- .../builtin/aws/tools/apply_guardrail.yaml | 11 + .../builtin/aws/tools/lambda_yaml_to_json.py | 71 ++++ .../aws/tools/lambda_yaml_to_json.yaml | 53 +++ .../aws/tools/sagemaker_text_rerank.py | 6 +- .../builtin/aws/tools/sagemaker_tts.py | 95 ++++++ .../builtin/aws/tools/sagemaker_tts.yaml | 149 ++++++++ api/poetry.lock | 275 ++++++++++++++- api/pyproject.toml | 1 + 17 files changed, 1452 insertions(+), 72 deletions(-) create mode 100644 api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/tts/__init__.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/tts/tts.py create mode 100644 api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py create mode 100644 api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml create mode 100644 api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py create mode 100644 api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index f8e7757a96..3d4c5825af 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -1,17 +1,36 @@ import json import logging -from collections.abc import Generator -from typing import Any, Optional, Union +import re +from collections.abc import Generator, Iterator +from typing import Any, Optional, Union, cast +# from openai.types.chat import ChatCompletion, ChatCompletionChunk import boto3 +from sagemaker import Predictor, serializers +from sagemaker.session import Session -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, + PromptMessageContent, + PromptMessageContentType, PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) +def 
inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], stop:list, stream=False):
+    """
+    params:
+    predictor : SageMaker Predictor
+    messages (List[Dict[str,Any]]): message list.
+        messages = [
+        {"role": "system", "content":"please answer in Chinese"},
+        {"role": "user", "content": "who are you? what are you doing?"},
+        ]
+    params (Dict[str,Any]): model parameters for LLM.
+    stream (bool): False by default.
+
+    response:
+    result of inference if stream is False
+    Iterator of Chunks if stream is True
+    """
+    payload = {
+        "model" : params.get('model_name'),
+        "stop" : stop,
+        "messages": messages,
+        "stream" : stream,
+        "max_tokens" : params.get('max_new_tokens', params.get('max_tokens', 2048)),
+        "temperature" : params.get('temperature', 0.1),
+        "top_p" : params.get('top_p', 0.9),
+    }
+
+    if not stream:
+        response = predictor.predict(payload)
+        return response
+    else:
+        response_stream = predictor.predict_stream(payload)
+        return response_stream
 
 class SageMakerLargeLanguageModel(LargeLanguageModel):
     """
-    Model class for Cohere large language model.
+    Model class for SageMaker large language model.
     """
     sagemaker_client: Any = None
+    sagemaker_sess : Any = None
+    predictor : Any = None
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                       tools: list[PromptMessageTool],
+                                       resp: bytes) -> LLMResult:
+        """
+        handle normal chat generate response
+        """
+        resp_obj = json.loads(resp.decode('utf-8'))
+        resp_str = resp_obj.get('choices')[0].get('message').get('content')
+
+        if len(resp_str) == 0:
+            raise InvokeServerUnavailableError("Empty response")
+
+        assistant_prompt_message = AssistantPromptMessage(
+            content=resp_str,
+            tool_calls=[]
+        )
+
+        prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+        completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
+
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)
+
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=None,
+            usage=usage,
+            message=assistant_prompt_message,
+        )
+
+        return response
+
+    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                     tools: list[PromptMessageTool],
+                                     resp: Iterator[bytes]) -> Generator:
+        """
+        handle stream chat generate response
+        """
+        full_response = ''
+        buffer = ""
+        for chunk_bytes in resp:
+            buffer += chunk_bytes.decode('utf-8')
+            last_idx = 0
+            # the endpoint streams SSE-style frames: "data: {...}\n\n"
+            for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer):
+                try:
+                    data = json.loads(match.group(1).strip())
+                    last_idx = match.span()[1]
+
+                    if "content" in data["choices"][0]["delta"]:
+                        chunk_content = data["choices"][0]["delta"]["content"]
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=chunk_content,
+                            tool_calls=[]
+                        )
+
+                        if data["choices"][0]['finish_reason'] is not None:
+                            temp_assistant_prompt_message = AssistantPromptMessage(
+                                content=full_response,
+                                tool_calls=[]
+                            )
+                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                            completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
+                            usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message,
+                                    finish_reason=data["choices"][0]['finish_reason'],
+                                    usage=usage
+                                ),
+                            )
+                        else:
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message
+                                ),
+                            )
+
+                        full_response += chunk_content
+                except (json.JSONDecodeError, KeyError, IndexError):
+                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
+
+            buffer = buffer[last_idx:]
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        # get model mode
-        model_mode = self.get_model_mode(model, credentials)
-
         if not self.sagemaker_client:
             access_key = credentials.get('access_key')
             secret_key = credentials.get('secret_key')
@@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             else:
                 self.sagemaker_client = boto3.client("sagemaker-runtime")
 
+            sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
+            self.predictor = Predictor(
+                endpoint_name=credentials.get('sagemaker_endpoint'),
+                sagemaker_session=sagemaker_session,
+                serializer=serializers.JSONSerializer(),
+            )
 
-        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
-        response_model = self.sagemaker_client.invoke_endpoint(
-            EndpointName=sagemaker_endpoint,
-            Body=json.dumps(
-                {
-                    "inputs": prompt_messages[0].content,
-                    "parameters": { "stop" : stop},
-                    "history" : []
-                }
-            ),
-            ContentType="application/json",
-        )
-        assistant_text = response_model['Body'].read().decode('utf8')
+        messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ]
+        response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)
 
-        # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
-        )
+        if stream:
+            if tools and len(tools) > 0:
+                raise InvokeBadRequestError(f"{model} does not support tool calls in stream mode")
 
-        usage = self._calc_response_usage(model, credentials, 0, 0)
+            return self._handle_chat_stream_response(model=model, credentials=credentials,
+                                                     prompt_messages=prompt_messages,
+                                                     tools=tools, resp=response)
 
+        return self._handle_chat_generate_response(model=model, credentials=credentials,
+                                                   prompt_messages=prompt_messages,
+                                                   tools=tools, resp=response)
 
-        response = LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=assistant_prompt_message,
-            usage=usage
-        )
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for OpenAI Compatibility API
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type":
"image_url", + "image_url": { + "url": message_content.data, + "detail": message_content.detail.value + } + } + sub_messages.append(sub_message_dict) + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") - return response + return message_dict + + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], + is_completion_model: bool = False) -> int: + def tokens(text: str): + return self._get_num_tokens_by_gpt2(text) + + if is_completion_model: + return sum(tokens(str(message.content)) for message in messages) + + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + if isinstance(value, list): + text = '' + for item in value: + if isinstance(item, dict) and item['type'] == 'text': + text += item['text'] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += tokens(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + if key == "function_call": + for t_key, t_value in value.items(): + num_tokens += tokens(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + else: + num_tokens += tokens(str(value)) + + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :return: """ # get model mode - model_mode = self.get_model_mode(model) - try: - return 0 + return self._num_tokens_from_messages(prompt_messages, tools) except Exception as e: raise self._transform_invoke_error(e) @@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): """ try: # get model mode - model_mode = self.get_model_mode(model) + pass except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): ) ] - completion_type = LLMMode.value_of(credentials["mode"]) - - if completion_type == LLMMode.CHAT: - print(f"completion_type : {LLMMode.CHAT.value}") - - if completion_type == LLMMode.COMPLETION: - print(f"completion_type : 
{LLMMode.COMPLETION.value}") + completion_type = LLMMode.value_of(credentials["mode"]).value features = [] diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index 0b06f54ef1..6b7cfc210b 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class SageMakerRerankModel(RerankModel): """ - Model class for Cohere rerank model. + Model class for SageMaker rerank model. """ sagemaker_client: Any = None diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py index 02d05f406c..6f3e02489f 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -1,10 +1,11 @@ import logging +import uuid +from typing import IO, Any from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) - class SageMakerProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -15,3 +16,28 @@ class SageMakerProvider(ModelProvider): :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. """ pass + +def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str: + ''' + return s3_uri of this file + ''' + s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3' + s3_client.put_object( + Body=file.read(), + Bucket=bucket, + Key=s3_key, + ContentType='audio/mp3' + ) + return s3_key + +def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str: + object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix) + try: + response = s3_client.generate_presigned_url('get_object', + Params={'Bucket': bucket_name, 'Key': object_key}, + ExpiresIn=expiration) + except Exception as e: + print(f"Error generating presigned URL: {e}") + return None + + return response \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml index 290cb0edab..87cd50f50c 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -21,6 +21,8 @@ supported_model_types: - llm - text-embedding - rerank + - speech2text + - tts configurate_methods: - customizable-model model_credential_schema: @@ -45,14 +47,10 @@ model_credential_schema: zh_Hans: 选择对话类型 en_US: Select completion mode options: - - value: completion - label: - en_US: Completion - zh_Hans: 补全 - value: chat label: en_US: Chat - zh_Hans: 对话 + zh_Hans: Chat - variable: sagemaker_endpoint label: en_US: sagemaker endpoint @@ -61,6 +59,76 @@ model_credential_schema: placeholder: zh_Hans: 请输出你的Sagemaker推理端点 en_US: Enter your Sagemaker Inference endpoint + - variable: audio_s3_cache_bucket + show_on: + - variable: __model_type + value: speech2text + label: + zh_Hans: 音频缓存桶(s3 bucket) + en_US: audio cache bucket(s3 bucket) + type: text-input + required: true + placeholder: + zh_Hans: sagemaker-us-east-1-******207838 + en_US: sagemaker-us-east-1-*******7838 + - variable: audio_model_type + show_on: + - variable: __model_type + value: tts + label: + en_US: Audio model type + type: select + required: 
+    placeholder:
+      zh_Hans: 语音模型类型
+      en_US: Audio model type
+    options:
+      - value: PresetVoice
+        label:
+          en_US: preset voice
+          zh_Hans: 内置音色
+      - value: CloneVoice
+        label:
+          en_US: clone voice
+          zh_Hans: 克隆音色
+      - value: CloneVoice_CrossLingual
+        label:
+          en_US: crosslingual clone voice
+          zh_Hans: 跨语种克隆音色
+      - value: InstructVoice
+        label:
+          en_US: Instruct voice
+          zh_Hans: 文字指令音色
+  - variable: prompt_audio
+    show_on:
+      - variable: __model_type
+        value: tts
+    label:
+      en_US: Mock Audio Source
+    type: text-input
+    required: false
+    placeholder:
+      zh_Hans: 被模仿的音色音频
+      en_US: source audio to be mocked
+  - variable: prompt_text
+    show_on:
+      - variable: __model_type
+        value: tts
+    label:
+      en_US: Prompt Audio Text
+    type: text-input
+    required: false
+    placeholder:
+      zh_Hans: 模仿音色的对应文本
+      en_US: text for the mocked source audio
+  - variable: instruct_text
+    show_on:
+      - variable: __model_type
+        value: tts
+    label:
+      en_US: instruct text for speaker
+    type: text-input
+    required: false
   - variable: aws_access_key_id
     required: false
     label:
diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py
new file mode 100644
index 0000000000..8b57f182fe
--- /dev/null
+++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py
@@ -0,0 +1,142 @@
+import json
+import logging
+from typing import IO, Any, Optional
+
+import boto3
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
+
+logger = logging.getLogger(__name__)
+
+class SageMakerSpeech2TextModel(Speech2TextModel):
+    """
+    Model class for SageMaker speech-to-text model.
+ """ + sagemaker_client: Any = None + s3_client : Any = None + + def _invoke(self, model: str, credentials: dict, + file: IO[bytes], user: Optional[str] = None) \ + -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + asr_text = None + + try: + if not self.sagemaker_client: + access_key = credentials.get('aws_access_key_id') + secret_key = credentials.get('aws_secret_access_key') + aws_region = credentials.get('aws_region') + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + self.s3_client = boto3.client("s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + + s3_prefix='dify/speech2text/' + sagemaker_endpoint = credentials.get('sagemaker_endpoint') + bucket = credentials.get('audio_s3_cache_bucket') + + s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) + payload = { + "audio_s3_presign_uri" : s3_presign_url + } + + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=sagemaker_endpoint, + Body=json.dumps(payload), + ContentType="application/json" + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + asr_text = json_obj['text'] + except Exception as e: + logger.exception(f'Exception {e}, line : {line}') + + return asr_text + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + pass + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={ }, + parameter_rules=[] + ) + + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/__init__.py b/api/core/model_runtime/model_providers/sagemaker/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py new file mode 100644 index 0000000000..315b31fd85 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -0,0 +1,287 @@ +import concurrent.futures +import copy +import json +import logging +from enum import Enum +from typing import Any, Optional + +import boto3 +import requests + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.__base.tts_model import TTSModel + +logger = logging.getLogger(__name__) + +class TTSModelType(Enum): + PresetVoice = "PresetVoice" + CloneVoice = "CloneVoice" + CloneVoice_CrossLingual = "CloneVoice_CrossLingual" + InstructVoice = "InstructVoice" + +class SageMakerText2SpeechModel(TTSModel): + + sagemaker_client: Any = None + s3_client : Any = None + comprehend_client : Any = None + + def __init__(self): + # preset voices, need support custom voice + self.model_voices = { + '__default': { + 'all': [ + {'name': 'Default', 'value': 'default'}, + ] + }, + 'CosyVoice': { + 'zh-Hans': [ + {'name': '中文男', 'value': '中文男'}, + {'name': '中文女', 'value': '中文女'}, + {'name': '粤语女', 'value': '粤语女'}, + ], + 'zh-Hant': [ + {'name': '中文男', 'value': '中文男'}, + {'name': '中文女', 'value': '中文女'}, + {'name': '粤语女', 'value': '粤语女'}, + ], + 'en-US': [ + {'name': '英文男', 'value': '英文男'}, + {'name': '英文女', 'value': '英文女'}, + ], + 'ja-JP': [ + {'name': '日语男', 'value': '日语男'}, + ], + 'ko-KR': [ + {'name': '韩语女', 'value': '韩语女'}, + ] + } + } + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + pass + + def _detect_lang_code(self, content:str, map_dict:dict=None): + map_dict = { + "zh" : "<|zh|>", + "en" : "<|en|>", + "ja" : "<|jp|>", + "zh-TW" : "<|yue|>", + "ko" : "<|ko|>" + } + + response = self.comprehend_client.detect_dominant_language(Text=content) + language_code = response['Languages'][0]['LanguageCode'] + + return map_dict.get(language_code, '<|zh|>') + + def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + if model_type == 
TTSModelType.PresetVoice.value and model_role: + return { "tts_text" : content_text, "role" : model_role } + if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: + return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + lang_tag = self._detect_lang_code(content_text) + return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + + raise RuntimeError(f"Invalid params for {model_type}") + + def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, + user: Optional[str] = None): + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param voice: model timbre + :param content_text: text content to be translated + :param user: unique user id + :return: text translated to audio file + """ + if not self.sagemaker_client: + access_key = credentials.get('aws_access_key_id') + secret_key = credentials.get('aws_secret_access_key') + aws_region = credentials.get('aws_region') + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + self.s3_client = boto3.client("s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + self.comprehend_client = boto3.client('comprehend', + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + self.comprehend_client = boto3.client('comprehend') + + model_type = credentials.get('audio_model_type', 'PresetVoice') + prompt_text = credentials.get('prompt_text') + prompt_audio = credentials.get('prompt_audio') + instruct_text = credentials.get('instruct_text') + sagemaker_endpoint = credentials.get('sagemaker_endpoint') + payload = self._build_tts_payload( + model_type, + content_text, + voice, + prompt_text, + prompt_audio, + instruct_text + ) + + return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS, + model_properties={}, + parameter_rules=[] + ) + + return entity + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def _get_model_default_voice(self, model: str, credentials: dict) -> any: + return "" + + def _get_model_word_limit(self, model: str, credentials: dict) -> int: + return 15 + + def _get_model_audio_type(self, model: str, credentials: dict) -> str: + return "mp3" + + def _get_model_workers_limit(self, model: str, credentials: dict) -> int: + return 5 + + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + audio_model_name = 'CosyVoice' + for key, voices in self.model_voices.items(): + if key in audio_model_name: + if language and language in voices: + return voices[language] + elif 'all' in voices: + return voices['all'] + + return self.model_voices['__default']['all'] + + def _invoke_sagemaker(self, payload:dict, endpoint:str): + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + return json_obj + + def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> any: + """ + _tts_invoke_streaming text2speech model + + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + try: + lang_tag = '' + if model_type == TTSModelType.CloneVoice_CrossLingual.value: + lang_tag = payload.pop('lang_tag') + + word_limit = self._get_model_word_limit(model='', credentials={}) + content_text = payload.get("tts_text") + if len(content_text) > word_limit: + split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ] + len_sent = len(sentences) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent)) + payloads = [ copy.deepcopy(payload) for i in range(len_sent) ] + for idx in range(len_sent): + payloads[idx]["tts_text"] = sentences[idx] + + futures = [ executor.submit( + self._invoke_sagemaker, + payload=payload, + endpoint=sagemaker_endpoint, + ) + for payload in payloads] + + for index, future in enumerate(futures): + resp = future.result() + audio_bytes = requests.get(resp.get('s3_presign_url')).content + for i in range(0, len(audio_bytes), 1024): + yield audio_bytes[i:i + 1024] + else: + resp = self._invoke_sagemaker(payload, sagemaker_endpoint) + audio_bytes = requests.get(resp.get('s3_presign_url')).content + + for i in range(0, len(audio_bytes), 1024): + yield audio_bytes[i:i + 1024] + except Exception as ex: + raise InvokeBadRequestError(str(ex)) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 9c006733bd..06fcf8a453 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -3,6 +3,7 @@ import logging from typing import Any, Union import boto3 +from botocore.exceptions import BotoCoreError from pydantic import BaseModel, Field 
from core.tools.entities.tool_entities import ToolInvokeMessage @@ -16,7 +17,7 @@ class GuardrailParameters(BaseModel): guardrail_version: str = Field(..., description="The version of the guardrail") source: str = Field(..., description="The source of the content") text: str = Field(..., description="The text to apply the guardrail to") - aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client") + aws_region: str = Field(..., description="AWS region for the Bedrock client") class ApplyGuardrailTool(BuiltinTool): def _invoke(self, @@ -40,6 +41,8 @@ class ApplyGuardrailTool(BuiltinTool): source=params.source, content=[{"text": {"text": params.text}}] ) + + logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") # Check for empty response if not response: @@ -69,7 +72,7 @@ class ApplyGuardrailTool(BuiltinTool): return self.create_text_message(text=result) - except boto3.exceptions.BotoCoreError as e: + except BotoCoreError as e: error_message = f'AWS service error: {str(e)}' logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) @@ -80,4 +83,4 @@ class ApplyGuardrailTool(BuiltinTool): except Exception as e: error_message = f'An unexpected error occurred: {str(e)}' logger.error(error_message, exc_info=True) - return self.create_text_message(text=error_message) + return self.create_text_message(text=error_message) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml index 2b7c8abb44..66044e4ea8 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml @@ -54,3 +54,14 @@ parameters: zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。 llm_description: The content used for requesting guardrail review, which can be either user input or LLM output. form: llm + - name: aws_region + type: string + required: true + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'. + zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。 + llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'. 
+ form: form diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py new file mode 100644 index 0000000000..bb7f6840b8 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -0,0 +1,71 @@ +import json +import logging +from typing import Any, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +console_handler = logging.StreamHandler() +logger.addHandler(console_handler) + + +class LambdaYamlToJsonTool(BuiltinTool): + lambda_client: Any = None + + def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str: + msg = { + "body": yaml_content + } + logger.info(json.dumps(msg)) + + invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, + InvocationType='RequestResponse', + Payload=json.dumps(msg)) + response_body = invoke_response['Payload'] + + response_str = response_body.read().decode("utf-8") + resp_json = json.loads(response_str) + + logger.info(resp_json) + if resp_json['statusCode'] != 200: + raise Exception(f"Invalid status code: {response_str}") + + return resp_json['body'] + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.lambda_client: + aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region + if aws_region: + self.lambda_client = boto3.client("lambda", region_name=aws_region) + else: + self.lambda_client = boto3.client("lambda") + + yaml_content = tool_parameters.get('yaml_content', '') + if not yaml_content: + return self.create_text_message('Please input yaml_content') + + lambda_name = tool_parameters.get('lambda_name', '') + if not lambda_name: + return self.create_text_message('Please input lambda_name') + logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}') + + result = self._invoke_lambda(lambda_name, yaml_content) + logger.debug(result) + + return self.create_text_message(result) + except Exception as e: + return self.create_text_message(f'Exception: {str(e)}') + + console_handler.flush() \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml new file mode 100644 index 0000000000..919c285348 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml @@ -0,0 +1,53 @@ +identity: + name: lambda_yaml_to_json + author: AWS + label: + en_US: LambdaYamlToJson + zh_Hans: LambdaYamlToJson + pt_BR: LambdaYamlToJson + icon: icon.svg +description: + human: + en_US: A tool to convert yaml to json using AWS Lambda. + zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。 + pt_BR: A tool to convert yaml to json using AWS Lambda. + llm: A tool to convert yaml to json. 
+parameters: + - name: yaml_content + type: string + required: true + label: + en_US: YAML content to convert for + zh_Hans: YAML 内容 + pt_BR: YAML content to convert for + human_description: + en_US: YAML content to convert for + zh_Hans: YAML 内容 + pt_BR: YAML content to convert for + llm_description: YAML content to convert for + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of lambda + zh_Hans: Lambda 所在的region + pt_BR: region of lambda + human_description: + en_US: region of lambda + zh_Hans: Lambda 所在的region + pt_BR: region of lambda + llm_description: region of lambda + form: form + - name: lambda_name + type: string + required: false + label: + en_US: name of lambda + zh_Hans: Lambda 名称 + pt_BR: name of lambda + human_description: + en_US: name of lambda + zh_Hans: Lambda 名称 + pt_BR: name of lambda + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index d4bc446e5b..2b3a3eaad6 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -78,9 +78,7 @@ class SageMakerReRankTool(BuiltinTool): sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True) line = 9 - results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False) - return self.create_text_message(text=results_str) + return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ] except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') - \ No newline at end of file + return self.create_text_message(f'Exception {str(e)}, line : {line}') \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py new file mode 100644 index 0000000000..a100e62230 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -0,0 +1,95 @@ +import json +from enum import Enum +from typing import Any, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class TTSModelType(Enum): + PresetVoice = "PresetVoice" + CloneVoice = "CloneVoice" + CloneVoice_CrossLingual = "CloneVoice_CrossLingual" + InstructVoice = "InstructVoice" + +class SageMakerTTSTool(BuiltinTool): + sagemaker_client: Any = None + sagemaker_endpoint:str = None + s3_client : Any = None + comprehend_client : Any = None + + def _detect_lang_code(self, content:str, map_dict:dict=None): + map_dict = { + "zh" : "<|zh|>", + "en" : "<|en|>", + "ja" : "<|jp|>", + "zh-TW" : "<|yue|>", + "ko" : "<|ko|>" + } + + response = self.comprehend_client.detect_dominant_language(Text=content) + language_code = response['Languages'][0]['LanguageCode'] + return map_dict.get(language_code, '<|zh|>') + + def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + if model_type == TTSModelType.PresetVoice.value and model_role: + return { "tts_text" : content_text, "role" : model_role } + if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: + return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + lang_tag = 
self._detect_lang_code(content_text) + return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + + raise RuntimeError(f"Invalid params for {model_type}") + + def _invoke_sagemaker(self, payload:dict, endpoint:str): + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + return json_obj + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.sagemaker_client: + aws_region = tool_parameters.get('aws_region') + if aws_region: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + self.comprehend_client = boto3.client('comprehend') + + if not self.sagemaker_endpoint: + self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + + tts_text = tool_parameters.get('tts_text') + tts_infer_type = tool_parameters.get('tts_infer_type') + + voice = tool_parameters.get('voice') + mock_voice_audio = tool_parameters.get('mock_voice_audio') + mock_voice_text = tool_parameters.get('mock_voice_text') + voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt') + payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt) + + result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) + + return self.create_text_message(text=result['s3_presign_url']) + + except Exception as e: + return self.create_text_message(f'Exception {str(e)}') \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml new file mode 100644 index 0000000000..a6a61dd4aa --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml @@ -0,0 +1,149 @@ +identity: + name: sagemaker_tts + author: AWS + label: + en_US: SagemakerTTS + zh_Hans: Sagemaker语音合成 + pt_BR: SagemakerTTS + icon: icon.svg +description: + human: + en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool + zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本 + pt_BR: A tool for Speech synthesis. + llm: A tool for Speech synthesis. 
You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool +parameters: + - name: sagemaker_endpoint + type: string + required: true + label: + en_US: sagemaker endpoint for tts + zh_Hans: 语音生成的SageMaker端点 + pt_BR: sagemaker endpoint for tts + human_description: + en_US: sagemaker endpoint for tts + zh_Hans: 语音生成的SageMaker端点 + pt_BR: sagemaker endpoint for tts + llm_description: sagemaker endpoint for tts + form: form + - name: tts_text + type: string + required: true + label: + en_US: tts text + zh_Hans: 语音合成原文 + pt_BR: tts text + human_description: + en_US: tts text + zh_Hans: 语音合成原文 + pt_BR: tts text + llm_description: tts text + form: llm + - name: tts_infer_type + type: select + required: false + label: + en_US: tts infer type + zh_Hans: 合成方式 + pt_BR: tts infer type + human_description: + en_US: tts infer type + zh_Hans: 合成方式 + pt_BR: tts infer type + llm_description: tts infer type + options: + - value: PresetVoice + label: + en_US: preset voice + zh_Hans: 预置音色 + - value: CloneVoice + label: + en_US: clone voice + zh_Hans: 克隆音色 + - value: CloneVoice_CrossLingual + label: + en_US: clone crossLingual voice + zh_Hans: 克隆音色(跨语言) + - value: InstructVoice + label: + en_US: instruct voice + zh_Hans: 指令音色 + form: form + - name: voice + type: select + required: false + label: + en_US: preset voice + zh_Hans: 预置音色 + pt_BR: preset voice + human_description: + en_US: preset voice + zh_Hans: 预置音色 + pt_BR: preset voice + llm_description: preset voice + options: + - value: 中文男 + label: + en_US: zh-cn male + zh_Hans: 中文男 + - value: 中文女 + label: + en_US: zh-cn female + zh_Hans: 中文女 + - value: 粤语女 + label: + en_US: zh-TW female + zh_Hans: 粤语女 + form: form + - name: mock_voice_audio + type: string + required: false + label: + en_US: clone voice link + zh_Hans: 克隆音频链接 + pt_BR: clone voice link + human_description: + en_US: clone voice link + zh_Hans: 克隆音频链接 + pt_BR: clone voice link + llm_description: clone voice link + form: llm + - name: mock_voice_text + type: string + required: false + label: + en_US: text of clone voice + zh_Hans: 克隆音频对应文本 + pt_BR: text of clone voice + human_description: + en_US: text of clone voice + zh_Hans: 克隆音频对应文本 + pt_BR: text of clone voice + llm_description: text of clone voice + form: llm + - name: voice_instruct_prompt + type: string + required: false + label: + en_US: instruct prompt for voice + zh_Hans: 音色指令文本 + pt_BR: instruct prompt for voice + human_description: + en_US: instruct prompt for voice + zh_Hans: 音色指令文本 + pt_BR: instruct prompt for voice + llm_description: instruct prompt for voice + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + human_description: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + llm_description: region of sagemaker endpoint + form: form diff --git a/api/poetry.lock b/api/poetry.lock index 93dd4c71bb..21708795b1 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -520,22 +520,22 @@ files = [ [[package]] name = "attrs" -version = "24.2.0" +version = "23.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, - {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, + {file = 
"attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] [[package]] name = "authlib" @@ -1719,6 +1719,17 @@ lz4 = ["clickhouse-cityhash (>=1.0.2.1)", "lz4", "lz4 (<=3.0.1)"] numpy = ["numpy (>=1.12.0)", "pandas (>=0.24.0)"] zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"] +[[package]] +name = "cloudpickle" +version = "2.2.1" +description = "Extended pickling support for Python objects" +optional = false +python-versions = ">=3.6" +files = [ + {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"}, + {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"}, +] + [[package]] name = "cloudscraper" version = "1.2.71" @@ -2151,6 +2162,21 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "dill" +version = "0.3.8" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distro" version = "1.9.0" @@ -2162,6 +2188,28 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "docker" +version = "7.1.0" +description = "A Python library for the Docker Engine API." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"}, + {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"}, +] + +[package.dependencies] +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" + +[package.extras] +dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"] +docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] +ssh = ["paramiko (>=2.4.3)"] +websockets = ["websocket-client (>=1.3.0)"] + [[package]] name = "docstring-parser" version = "0.16" @@ -3309,6 +3357,21 @@ typing-extensions = "*" [package.extras] dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] +[[package]] +name = "google-pasta" +version = "0.2.0" +description = "pasta is an AST-based Python refactoring library" +optional = false +python-versions = "*" +files = [ + {file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"}, + {file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"}, + {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "google-resumable-media" version = "2.7.2" @@ -3930,22 +3993,22 @@ files = [ [[package]] name = "importlib-metadata" -version = "8.4.0" +version = "6.11.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"}, - {file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"}, + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" @@ -4929,6 +4992,22 @@ files = [ [package.extras] test = ["mypy (>=1.0)", "pytest (>=7.0.0)"] +[[package]] +name = "mock" +version = "4.0.3" +description = "Rolling backport of unittest.mock for all Pythons" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mock-4.0.3-py3-none-any.whl", hash = 
"sha256:122fcb64ee37cfad5b3f48d7a7d51875d7031aaf3d8be7c42e2bee25044eee62"}, + {file = "mock-4.0.3.tar.gz", hash = "sha256:7d3fbbde18228f4ff2f1f119a45cdffa458b4c0dee32eb4d2bb2f82554bac7bc"}, +] + +[package.extras] +build = ["blurb", "twine", "wheel"] +docs = ["sphinx"] +test = ["pytest (<5.4)", "pytest-cov"] + [[package]] name = "monotonic" version = "1.6" @@ -5128,6 +5207,30 @@ files = [ {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, ] +[[package]] +name = "multiprocess" +version = "0.70.16" +description = "better multiprocessing and multithreading in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, +] + +[package.dependencies] +dill = ">=0.3.8" + [[package]] name = "multitasking" version = "0.0.11" @@ -5955,6 +6058,23 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pathos" +version = "0.3.2" +description = "parallel graph management and execution in heterogeneous computing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathos-0.3.2-py3-none-any.whl", hash = "sha256:d669275e6eb4b3fbcd2846d7a6d1bba315fe23add0c614445ba1408d8b38bafe"}, + {file = "pathos-0.3.2.tar.gz", hash = "sha256:4f2a42bc1e10ccf0fe71961e7145fc1437018b6b21bd93b2446abc3983e49a7a"}, +] + +[package.dependencies] +dill = ">=0.3.8" +multiprocess = ">=0.70.16" +pox = ">=0.3.4" +ppft = ">=1.7.6.8" + [[package]] name = "peewee" version = "3.17.6" @@ -6196,6 +6316,31 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] sentry = ["django", "sentry-sdk"] test = ["coverage", "django", "flake8", "freezegun (==0.3.15)", "mock 
(>=2.0.0)", "pylint", "pytest", "pytest-timeout"] +[[package]] +name = "pox" +version = "0.3.4" +description = "utilities for filesystem exploration and automated builds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pox-0.3.4-py3-none-any.whl", hash = "sha256:651b8ae8a7b341b7bfd267f67f63106daeb9805f1ac11f323d5280d2da93fdb6"}, + {file = "pox-0.3.4.tar.gz", hash = "sha256:16e6eca84f1bec3828210b06b052adf04cf2ab20c22fd6fbef5f78320c9a6fed"}, +] + +[[package]] +name = "ppft" +version = "1.7.6.8" +description = "distributed and parallel Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ppft-1.7.6.8-py3-none-any.whl", hash = "sha256:de2dd4b1b080923dd9627fbdea52649fd741c752fce4f3cf37e26f785df23d9b"}, + {file = "ppft-1.7.6.8.tar.gz", hash = "sha256:76a429a7d7b74c4d743f6dba8351e58d62b6432ed65df9fe204790160dab996d"}, +] + +[package.extras] +dill = ["dill (>=0.3.8)"] + [[package]] name = "primp" version = "0.6.1" @@ -8004,6 +8149,84 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "sagemaker" +version = "2.231.0" +description = "Open source library for training and deploying models on Amazon SageMaker." +optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker-2.231.0-py3-none-any.whl", hash = "sha256:5b6d84484a58c6ac8b22af42c6c5e0ea3c5f42d719345fe6aafba42f93635000"}, + {file = "sagemaker-2.231.0.tar.gz", hash = "sha256:d49ee9c35725832dd9810708938af723201b831e82924a3a6ac1c4260a3d8239"}, +] + +[package.dependencies] +attrs = ">=23.1.0,<24" +boto3 = ">=1.34.142,<2.0" +cloudpickle = "2.2.1" +docker = "*" +google-pasta = "*" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "*" +numpy = ">=1.9.0,<2.0" +packaging = ">=20.0" +pandas = "*" +pathos = "*" +platformdirs = "*" +protobuf = ">=3.12,<5.0" +psutil = "*" +pyyaml = ">=6.0,<7.0" +requests = "*" +sagemaker-core = ">=1.0.0,<2.0.0" +schema = "*" +smdebug-rulesconfig = "1.0.1" +tblib = ">=1.7.0,<4" +tqdm = "*" +urllib3 = ">=1.26.8,<3.0.0" + +[package.extras] +all = ["accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<8.0.0)", "fastapi (>=0.111.0)", "nest-asyncio", "pyspark (==3.3.1)", "pyyaml (>=5.4.1,<7)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)"] +feature-processor = ["pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3-3"] +huggingface = ["accelerate (>=0.24.1,<=0.27.0)", "fastapi (>=0.111.0)", "nest-asyncio", "sagemaker-schema-inference-artifacts (>=0.0.5)", "uvicorn (>=0.30.1)"] +local = ["docker (>=5.0.2,<8.0.0)", "pyyaml (>=5.4.1,<7)", "urllib3 (>=1.26.8,<3.0.0)"] +scipy = ["scipy (==1.10.1)"] +test = ["accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.3)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "build[virtualenv] (==1.2.1)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", "docker (>=5.0.2,<8.0.0)", "fabric (==2.6.0)", "fastapi (>=0.111.0)", "flake8 (==4.0.1)", "huggingface-hub (>=0.23.4)", "jinja2 (==3.1.4)", "mlflow (>=2.12.2,<2.13)", "mock (==4.0.3)", "nbformat (>=5.9,<6)", "nest-asyncio", "numpy (>=1.24.0)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark 
(==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "pyyaml (==6.0)", "pyyaml (>=5.4.1,<7)", "requests (==2.32.2)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)", "xgboost (>=1.6.2,<=1.7.6)"] + +[[package]] +name = "sagemaker-core" +version = "1.0.2" +description = "An python package for sagemaker core functionalities" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker_core-1.0.2-py3-none-any.whl", hash = "sha256:ce8d38a4a32efa83e4bc037a8befc7e29f87cd3eaf99acc4472b607f75a0f45a"}, + {file = "sagemaker_core-1.0.2.tar.gz", hash = "sha256:8fb942aac5e7ed928dab512ffe6facf8c6bdd4595df63c59c0bd0795ea434f8d"}, +] + +[package.dependencies] +boto3 = ">=1.34.0,<2.0.0" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "<5.0.0" +mock = ">4.0,<5.0" +platformdirs = ">=4.0.0,<5.0.0" +pydantic = ">=1.7.0,<3.0.0" +PyYAML = ">=6.0,<7.0" +rich = ">=13.0.0,<14.0.0" + +[package.extras] +codegen = ["black (>=24.3.0,<25.0.0)", "pandas (>=2.0.0,<3.0.0)", "pylint (>=3.0.0,<4.0.0)", "pytest (>=8.0.0,<9.0.0)"] + +[[package]] +name = "schema" +version = "0.7.7" +description = "Simple data validation library" +optional = false +python-versions = "*" +files = [ + {file = "schema-0.7.7-py2.py3-none-any.whl", hash = "sha256:5d976a5b50f36e74e2157b47097b60002bd4d42e65425fcc9c9befadb4255dde"}, + {file = "schema-0.7.7.tar.gz", hash = "sha256:7da553abd2958a19dc2547c388cde53398b39196175a9be59ea1caf5ab0a1807"}, +] + [[package]] name = "scikit-learn" version = "1.5.1" @@ -8276,6 +8499,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "smdebug-rulesconfig" +version = "1.0.1" +description = "SMDebug RulesConfig" +optional = false +python-versions = ">=2.7" +files = [ + {file = "smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl", hash = "sha256:104da3e6931ecf879dfc687ca4bbb3bee5ea2bc27f4478e9dbb3ee3655f1ae61"}, + {file = "smdebug_rulesconfig-1.0.1.tar.gz", hash = "sha256:7a19e6eb2e6bcfefbc07e4a86ef7a88f32495001a038bf28c7d8e77ab793fcd6"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -8473,6 +8707,17 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tblib" +version = "3.0.0" +description = "Traceback serialization library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"}, + {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"}, +] + [[package]] name = "tcvectordb" version = "1.3.2" @@ -10126,4 +10371,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "78c7db0bf525a72f4c8309e3363304d1a0a23cf0a6836bfb974a38a1fcde9158" +content-hash = "c3c637d643f4dcb3e35d0e7f2a3a4fbaf2a730512a4ca31adce5884c94f07f57" diff --git a/api/pyproject.toml b/api/pyproject.toml index 52e6147345..826b7debc4 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -113,6 +113,7 @@ azure-identity = "1.16.1" azure-storage-blob = "12.13.0" beautifulsoup4 = "4.12.2" boto3 = "1.34.148" +sagemaker = "2.231.0" bs4 = "~0.0.1" cachetools = "~5.3.0" celery = "~5.3.6"
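
Note on the new dependency: the sagemaker = "2.231.0" pin added above backs the Predictor-based endpoint invocation this patch introduces in the SageMaker model provider. A minimal, self-contained sketch of that usage pattern follows; the endpoint name, region, and payload fields are illustrative assumptions, not values taken from this patch:

    import json

    import boto3
    from sagemaker import Predictor, serializers
    from sagemaker.session import Session

    # Hypothetical endpoint name and region -- substitute your own deployment.
    ENDPOINT_NAME = "my-chat-llm-endpoint"
    boto_sess = boto3.Session(region_name="us-east-1")

    # Wrap the boto3 session so credentials and region are explicit.
    sm_session = Session(boto_session=boto_sess)

    # JSON request in, raw bytes out (the default deserializer);
    # the caller decodes the response itself.
    predictor = Predictor(
        endpoint_name=ENDPOINT_NAME,
        sagemaker_session=sm_session,
        serializer=serializers.JSONSerializer(),
    )

    # Chat-style payload; field names are assumptions for illustration.
    payload = {
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
        "max_tokens": 256,
    }
    raw = predictor.predict(payload)           # bytes
    print(json.loads(raw.decode("utf-8")))     # parsed completion dict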