diff --git a/api/core/model_runtime/model_providers/bedrock/__init__.py b/api/core/model_runtime/model_providers/bedrock/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg new file mode 100644 index 0000000000..667db50800 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg new file mode 100644 index 0000000000..6a0235af92 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py new file mode 100644 index 0000000000..aa322fc664 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py @@ -0,0 +1,30 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + +class BedrockProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + If validation fails, raise an exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + # use the `amazon.titan-text-lite-v1` model for validation + model_instance.validate_credentials( + model='amazon.titan-text-lite-v1', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml new file mode 100644 index 0000000000..1458b830cd --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml @@ -0,0 +1,71 @@ +provider: bedrock +label: + en_US: AWS +description: + en_US: AWS Bedrock's models. +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FCFDFF" +help: + title: + en_US: Get your Access Key and Secret Access Key from AWS Console + url: + en_US: https://console.aws.amazon.com/ +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: aws_access_key_id + required: true + label: + en_US: Access Key + zh_Hans: Access Key + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 在此输入您的 Access Key + - variable: aws_secret_access_key + required: true + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 在此输入您的 Secret Access Key + - variable: aws_region + required: true + label: + en_US: AWS Region + zh_Hans: AWS 地区 + type: select + default: us-east-1 + options: + - value: us-east-1 + label: + en_US: US East (N. Virginia) + zh_Hans: US East (N.
Virginia) + - value: us-west-2 + label: + en_US: US West (Oregon) + zh_Hans: US West (Oregon) + - value: ap-southeast-1 + label: + en_US: Asia Pacific (Singapore) + zh_Hans: Asia Pacific (Singapore) + - value: ap-northeast-1 + label: + en_US: Asia Pacific (Tokyo) + zh_Hans: Asia Pacific (Tokyo) + - value: eu-central-1 + label: + en_US: Europe (Frankfurt) + zh_Hans: Europe (Frankfurt) + - value: us-gov-west-1 + label: + en_US: AWS GovCloud (US-West) + zh_Hans: AWS GovCloud (US-West) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/__init__.py b/api/core/model_runtime/model_providers/bedrock/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml new file mode 100644 index 0000000000..c4be732f2e --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml @@ -0,0 +1,10 @@ +- amazon.titan-text-express-v1 +- amazon.titan-text-lite-v1 +- anthropic.claude-instant-v1 +- anthropic.claude-v1 +- anthropic.claude-v2 +- anthropic.claude-v2:1 +- cohere.command-light-text-v14 +- cohere.command-text-v14 +- meta.llama2-13b-chat-v1 +- meta.llama2-70b-chat-v1 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml new file mode 100644 index 0000000000..65dad02969 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml @@ -0,0 +1,47 @@ +model: ai21.j2-mid-v1 +label: + en_US: J2 Mid V1 +model_type: llm +model_properties: + mode: completion + context_size: 8191 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: maxTokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 + - name: count_penalty + label: + en_US: Count Penalty + required: false + type: float + default: 0 + min: 0 + max: 1 + - name: presence_penalty + label: + en_US: Presence Penalty + required: false + type: float + default: 0 + min: 0 + max: 5 + - name: frequency_penalty + label: + en_US: Frequency Penalty + required: false + type: float + default: 0 + min: 0 + max: 500 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml new file mode 100644 index 0000000000..b72f8064bd --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml @@ -0,0 +1,47 @@ +model: ai21.j2-ultra-v1 +label: + en_US: J2 Ultra V1 +model_type: llm +model_properties: + mode: completion + context_size: 8191 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: maxTokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 + - name: count_penalty + label: + en_US: Count Penalty + required: false + type: float + default: 0 + min: 0 + max: 1 + - name: presence_penalty + label: + en_US: Presence Penalty + required: false + type: float + default: 0 + min: 0 + max: 5 + - name: frequency_penalty + label: + en_US: Frequency Penalty + required: false + type: float + default: 0 + min: 0 + max: 500 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git 
a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml new file mode 100644 index 0000000000..64f992b913 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml @@ -0,0 +1,25 @@ +model: amazon.titan-text-express-v1 +label: + en_US: Titan Text G1 - Express +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: maxTokenCount + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 8000 +pricing: + input: '0.0008' + output: '0.0016' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml new file mode 100644 index 0000000000..69b298b571 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml @@ -0,0 +1,25 @@ +model: amazon.titan-text-lite-v1 +label: + en_US: Titan Text G1 - Lite +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: maxTokenCount + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 +pricing: + input: '0.0003' + output: '0.0004' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml new file mode 100644 index 0000000000..94b741f50d --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml @@ -0,0 +1,35 @@ +model: anthropic.claude-instant-v1 +label: + en_US: Claude Instant V1 +model_type: llm +model_properties: + mode: chat + context_size: 100000 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top K + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 250 + min: 0 + max: 500 + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.0008' + output: '0.0024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml new file mode 100644 index 0000000000..1c85923335 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml @@ -0,0 +1,35 @@ +model: anthropic.claude-v1 +label: + en_US: Claude V1 +model_type: llm +model_properties: + mode: chat + context_size: 100000 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top K + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + default: 250 + min: 0 + max: 500 + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml new file mode 100644 index 0000000000..d12e7fce90 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml @@ -0,0 +1,35 @@ +model: anthropic.claude-v2 +label: + en_US: Claude V2 +model_type: llm +model_properties: + mode: chat + context_size: 100000 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top K + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 250 + min: 0 + max: 500 + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml new file mode 100644 index 0000000000..c5502daf4a --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml @@ -0,0 +1,35 @@ +model: anthropic.claude-v2:1 +label: + en_US: Claude V2.1 +model_type: llm +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: topP + use_template: top_p + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top K + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 250 + min: 0 + max: 500 + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml new file mode 100644 index 0000000000..1fad910058 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml @@ -0,0 +1,35 @@ +model: cohere.command-light-text-v14 +label: + en_US: Command Light Text V14 +model_type: llm +model_properties: + mode: completion + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: p + use_template: top_p + - name: k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + min: 0 + max: 500 + default: 0 + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.0003' + output: '0.0006' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml new file mode 100644 index 0000000000..ed775afd7a --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml @@ -0,0 +1,32 @@ +model: cohere.command-text-v14 +label: + en_US: Command Text V14 +model_type: llm +model_properties: + mode: completion + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.0015' + output: '0.0020' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py new file mode 100644 index 0000000000..269e814d99 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -0,0 +1,486 @@ +import logging +from typing import Generator, List, Optional, Union + +import boto3 +from botocore.exceptions import ClientError, EndpointConnectionError, NoRegionError, ServiceNotInRegionError, UnknownServiceError +from botocore.config import Config +import json + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta + +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, + PromptMessageTool, SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + +class BedrockLargeLanguageModel(LargeLanguageModel): + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # invoke model + return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + + def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str, + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model 
credentials + :param messages: prompt messages or message string + :param tools: tools for tool calling + :return: number of tokens + """ + prefix = model.split('.')[0] + + if isinstance(messages, str): + prompt = messages + else: + prompt = self._convert_messages_to_prompt(messages, prefix) + + return self._get_num_tokens_by_gpt2(prompt) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + + try: + ping_message = UserPromptMessage(content="ping") + self._generate(model=model, + credentials=credentials, + prompt_messages=[ping_message], + model_parameters={}, + stream=False) + + except ClientError as ex: + error_code = ex.response['Error']['Code'] + full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" + + raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) + + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str: + """ + Convert a single message to a string. + + :param message: PromptMessage to convert. + :return: String representation of the message. + """ + + if model_prefix == "anthropic": + human_prompt_prefix = "\n\nHuman:" + human_prompt_postfix = "" + ai_prompt = "\n\nAssistant:" + + elif model_prefix == "meta": + human_prompt_prefix = "\n[INST]" + human_prompt_postfix = "[\\INST]\n" + ai_prompt = "" + + elif model_prefix == "amazon": + human_prompt_prefix = "\n\nUser:" + human_prompt_postfix = "" + ai_prompt = "\n\nBot:" + + else: + human_prompt_prefix = "" + human_prompt_postfix = "" + ai_prompt = "" + + content = message.content + + if isinstance(message, UserPromptMessage): + message_text = f"{human_prompt_prefix} {content} {human_prompt_postfix}" + elif isinstance(message, AssistantPromptMessage): + message_text = f"{ai_prompt} {content}" + elif isinstance(message, SystemPromptMessage): + message_text = content + else: + raise ValueError(f"Got unknown type {message}") + + return message_text + + def _convert_messages_to_prompt(self, messages: List[PromptMessage], model_prefix: str) -> str: + """ + Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models + + :param messages: List of PromptMessage to combine. + :return: Combined string with necessary human_prompt and ai_prompt tags.
+ """ + if not messages: + return '' + + messages = messages.copy() # don't mutate the original list + if not isinstance(messages[-1], AssistantPromptMessage): + messages.append(AssistantPromptMessage(content="")) + + text = "".join( + self._convert_one_message_to_text(message, model_prefix) + for message in messages + ) + + # trim off the trailing ' ' that might come from the "Assistant: " + return text.rstrip() + + def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, stream: bool = True): + """ + Create payload for bedrock api call depending on model provider + """ + payload = dict() + + if model_prefix == "amazon": + payload["textGenerationConfig"] = { **model_parameters } + payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else []) + + payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + + elif model_prefix == "ai21": + payload["temperature"] = model_parameters.get("temperature") + payload["topP"] = model_parameters.get("topP") + payload["maxTokens"] = model_parameters.get("maxTokens") + payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + + # jurassic models only support a single stop sequence + if stop: + payload["stopSequences"] = stop[0] + + if model_parameters.get("presencePenalty"): + payload["presencePenalty"] = {model_parameters.get("presencePenalty")} + if model_parameters.get("frequencyPenalty"): + payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} + if model_parameters.get("countPenalty"): + payload["countPenalty"] = {model_parameters.get("countPenalty")} + + elif model_prefix == "anthropic": + payload = { **model_parameters } + payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else []) + + elif model_prefix == "cohere": + payload = { **model_parameters } + payload["prompt"] = prompt_messages[0].content + payload["stream"] = stream + + elif model_prefix == "meta": + payload = { **model_parameters } + payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + + else: + raise ValueError(f"Got unknown model prefix {model_prefix}") + + return payload + + def _generate(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + stop: Optional[List[str]] = None, stream: bool = True, + user: Optional[str] = None) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: credentials kwargs + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + client_config = Config( + region_name=credentials["aws_region"] + ) + + runtime_client = boto3.client( + service_name='bedrock-runtime', + config=client_config, + aws_access_key_id=credentials["aws_access_key_id"], + aws_secret_access_key=credentials["aws_secret_access_key"] + ) + + model_prefix = model.split('.')[0] + payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream) + + # need workaround for ai21 models which doesn't support streaming + if stream and model_prefix != "ai21": + invoke = runtime_client.invoke_model_with_response_stream + else: + invoke = runtime_client.invoke_model + 
+ try: + response = invoke( + body=json.dumps(payload), + modelId=model, + ) + except ClientError as ex: + error_code = ex.response['Error']['Code'] + full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" + raise self._map_client_to_invoke_error(error_code, full_error_msg) + + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: + raise InvokeConnectionError(str(ex)) + + except UnknownServiceError as ex: + raise InvokeServerUnavailableError(str(ex)) + + except Exception as ex: + raise InvokeError(str(ex)) + + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_response(self, model: str, credentials: dict, response: dict, + prompt_messages: list[PromptMessage]) -> LLMResult: + """ + Handle llm response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response + """ + response_body = json.loads(response.get('body').read().decode('utf-8')) + + finish_reason = response_body.get("error") + + if finish_reason is not None: + raise InvokeError(finish_reason) + + # get output text and calculate num tokens based on model / provider + model_prefix = model.split('.')[0] + + if model_prefix == "amazon": + output = response_body.get("results")[0].get("outputText").strip('\n') + prompt_tokens = response_body.get("inputTextTokenCount") + completion_tokens = response_body.get("results")[0].get("tokenCount") + + elif model_prefix == "ai21": + output = response_body.get('completions')[0].get('data').get('text') + prompt_tokens = len(response_body.get("prompt").get("tokens")) + completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) + + elif model_prefix == "anthropic": + output = response_body.get("completion") + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, output if output else '') + + elif model_prefix == "cohere": + output = response_body.get("generations")[0].get("text") + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, output if output else '') + + elif model_prefix == "meta": + output = response_body.get("generation").strip('\n') + prompt_tokens = response_body.get("prompt_token_count") + completion_tokens = response_body.get("generation_token_count") + + else: + raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") + + # construct assistant message from output + assistant_prompt_message = AssistantPromptMessage( + content=output + ) + + # calculate usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # construct response + result = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + ) + + return result + + def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict, + prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response chunk generator result + """ + model_prefix = model.split('.')[0] + if model_prefix == "ai21": + response_body = 
json.loads(response.get('body').read().decode('utf-8')) + + content = response_body.get('completions')[0].get('data').get('text') + finish_reason = response_body.get('completions')[0].get('finish_reason') + + prompt_tokens = len(response_body.get("prompt").get("tokens")) + completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content), + finish_reason=finish_reason, + usage=usage + ) + ) + return + + stream = response.get('body') + if not stream: + raise InvokeError('No response body') + + index = -1 + for event in stream: + chunk = event.get('chunk') + + if not chunk: + exception_name = next(iter(event)) + full_ex_msg = f"{exception_name}: {event[exception_name]['message']}" + + raise self._map_client_to_invoke_error(exception_name, full_ex_msg) + + payload = json.loads(chunk.get('bytes').decode()) + + model_prefix = model.split('.')[0] + if model_prefix == "amazon": + content_delta = payload.get("outputText").strip('\n') + finish_reason = payload.get("completion_reason") + + elif model_prefix == "anthropic": + content_delta = payload.get("completion") + finish_reason = payload.get("stop_reason") + + elif model_prefix == "cohere": + content_delta = payload.get("text") + finish_reason = payload.get("finish_reason") + + elif model_prefix == "meta": + content_delta = payload.get("generation").strip('\n') + finish_reason = payload.get("stop_reason") + + else: + raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") + + index += 1 + + assistant_prompt_message = AssistantPromptMessage( + content = content_delta if content_delta else '', + ) + + if not finish_reason: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message + ) + ) + + else: + # get num tokens from metrics in last chunk + prompt_tokens = payload["amazon-bedrock-invocationMetrics"]["inputTokenCount"] + completion_tokens = payload["amazon-bedrock-invocationMetrics"]["outputTokenCount"] + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + finish_reason=finish_reason, + usage=usage + ) + ) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller.
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [] + } + + def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> InvokeError: + """ + Map client error to invoke error + + :param error_code: error code + :param error_msg: error message + :return: invoke error + """ + + if error_code == "AccessDeniedException": + return InvokeAuthorizationError(error_msg) + elif error_code in ["ResourceNotFoundException", "ValidationException"]: + return InvokeBadRequestError(error_msg) + elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + return InvokeRateLimitError(error_msg) + elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + return InvokeServerUnavailableError(error_msg) + elif error_code == "ModelStreamErrorException": + return InvokeConnectionError(error_msg) + + return InvokeError(error_msg) \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml new file mode 100644 index 0000000000..a8d3704c15 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml @@ -0,0 +1,23 @@ +model: meta.llama2-13b-chat-v1 +label: + en_US: Llama 2 Chat 13B +model_type: llm +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_gen_len + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 +pricing: + input: '0.00075' + output: '0.00100' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml new file mode 100644 index 0000000000..77525e630b --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml @@ -0,0 +1,23 @@ +model: meta.llama2-70b-chat-v1 +label: + en_US: Llama 2 Chat 70B +model_type: llm +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_gen_len + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 +pricing: + input: '0.00195' + output: '0.00256' + unit: '0.001' + currency: USD diff --git a/api/tests/integration_tests/model_runtime/bedrock/__init__.py b/api/tests/integration_tests/model_runtime/bedrock/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py new file mode 100644 index 0000000000..c395a08fe2 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py @@ -0,0 +1,117 @@ +import os +from typing import Generator + +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from
core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel + +def test_validate_credentials(): + model = BedrockLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='meta.llama2-13b-chat-v1', + credentials={ + 'aws_region': 'us-east-1', + 'aws_access_key_id': 'invalid_key', + 'aws_secret_access_key': 'invalid_key' + } + ) + + model.validate_credentials( + model='meta.llama2-13b-chat-v1', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + } + ) + +def test_invoke_model(): + model = BedrockLargeLanguageModel() + + response = model.invoke( + model='meta.llama2-13b-chat-v1', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.0, + 'top_p': 1.0, + 'max_gen_len': 10 + }, + stop=['How'], + stream=False, + user="abc-123" + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + +def test_invoke_stream_model(): + model = BedrockLargeLanguageModel() + + response = model.invoke( + model='meta.llama2-13b-chat-v1', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.0, + 'max_gen_len': 100 + }, + stream=True, + user="abc-123" + ) + + assert isinstance(response, Generator) + + for chunk in response: + print(chunk) + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = BedrockLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model='meta.llama2-13b-chat-v1', + credentials = { + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!'
+ ) + ] + ) + + assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py new file mode 100644 index 0000000000..ce3d61c0f7 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py @@ -0,0 +1,21 @@ +import os + +import pytest +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider + +def test_validate_provider_credentials(): + provider = BedrockProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={} + ) + + provider.validate_provider_credentials( + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key_id": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + } + )
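As a quick reference for reviewers, the provider-specific prompt formats produced by _convert_one_message_to_text and _convert_messages_to_prompt for a system message plus a user message (with the empty assistant message the helper appends) come out roughly as follows. This is derived by hand from the prefixes in the code above, so the exact whitespace may differ slightly:

# Illustrative only: expected prompts for
# [SystemPromptMessage("You are helpful."), UserPromptMessage("Hello World!")]
anthropic_prompt = "You are helpful.\n\nHuman: Hello World! \n\nAssistant:"
amazon_prompt = "You are helpful.\n\nUser: Hello World! \n\nBot:"
meta_prompt = "You are helpful.\n[INST] Hello World! [\\INST]"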
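The non-streaming request path that _create_payload and _generate implement can also be exercised directly against the public boto3 bedrock-runtime client. The sketch below is a minimal hand-written example and not part of the patch: the region, credentials, model id and prompt are placeholders, and the body uses the legacy Anthropic text-completions format built by the "anthropic" branch of _create_payload.

import json

import boto3

# Placeholder region/credentials; in the provider these come from the
# aws_region / aws_access_key_id / aws_secret_access_key credential fields.
client = boto3.client(
    "bedrock-runtime",
    region_name="us-east-1",
    aws_access_key_id="AKIA...",
    aws_secret_access_key="...",
)

# Anthropic text-completions body, mirroring the "anthropic" branch of _create_payload.
body = {
    "prompt": "\n\nHuman: Hello World! \n\nAssistant:",
    "max_tokens_to_sample": 256,
    "temperature": 0.0,
    "stop_sequences": ["\n\nHuman:"],
}

response = client.invoke_model(
    modelId="anthropic.claude-instant-v1",
    body=json.dumps(body),
)

# The response body is a streaming blob; read and JSON-decode it,
# as _handle_generate_response does.
result = json.loads(response["body"].read().decode("utf-8"))
print(result.get("completion"))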
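The streaming path uses invoke_model_with_response_stream, whose response body is an event stream of "chunk" records that _handle_generate_stream_response decodes one by one. A rough consumption sketch, again with a placeholder model id and prompt and assuming AWS credentials are available from the environment:

import json

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")

response = client.invoke_model_with_response_stream(
    modelId="anthropic.claude-instant-v1",
    body=json.dumps({
        "prompt": "\n\nHuman: Hello World! \n\nAssistant:",
        "max_tokens_to_sample": 256,
    }),
)

for event in response["body"]:
    chunk = event.get("chunk")
    if not chunk:
        # Non-chunk events carry error payloads, which llm.py maps to InvokeError subclasses.
        continue
    payload = json.loads(chunk["bytes"].decode())
    # Claude chunks carry incremental text in "completion"; the final chunk also
    # includes "amazon-bedrock-invocationMetrics" with the token counts used for usage.
    print(payload.get("completion", ""), end="")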