diff --git a/api/core/model_runtime/model_providers/bedrock/__init__.py b/api/core/model_runtime/model_providers/bedrock/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..667db50800
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg
@@ -0,0 +1,14 @@
+
diff --git a/api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..6a0235af92
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg
@@ -0,0 +1,15 @@
+
diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py
new file mode 100644
index 0000000000..aa322fc664
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py
@@ -0,0 +1,30 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+class BedrockProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
+ model_instance = self.get_model_instance(ModelType.LLM)
+
+            # use a lightweight Titan text model to validate the provider credentials
+ model_instance.validate_credentials(
+ model='amazon.titan-text-lite-v1',
+ credentials=credentials
+ )
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+ logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+ raise ex
diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml
new file mode 100644
index 0000000000..1458b830cd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml
@@ -0,0 +1,71 @@
+provider: bedrock
+label:
+ en_US: AWS
+description:
+  en_US: Models provided by AWS Bedrock.
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#FCFDFF"
+help:
+ title:
+ en_US: Get your Access Key and Secret Access Key from AWS Console
+ url:
+ en_US: https://console.aws.amazon.com/
+supported_model_types:
+ - llm
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: aws_access_key_id
+ required: true
+ label:
+ en_US: Access Key
+ zh_Hans: Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Access Key
+ zh_Hans: 在此输入您的 Access Key
+ - variable: aws_secret_access_key
+ required: true
+ label:
+ en_US: Secret Access Key
+ zh_Hans: Secret Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Secret Access Key
+ zh_Hans: 在此输入您的 Secret Access Key
+ - variable: aws_region
+ required: true
+ label:
+ en_US: AWS Region
+ zh_Hans: AWS 地区
+ type: select
+ default: us-east-1
+ options:
+ - value: us-east-1
+ label:
+ en_US: US East (N. Virginia)
+ zh_Hans: US East (N. Virginia)
+ - value: us-west-2
+ label:
+ en_US: US West (Oregon)
+ zh_Hans: US West (Oregon)
+ - value: ap-southeast-1
+ label:
+ en_US: Asia Pacific (Singapore)
+ zh_Hans: Asia Pacific (Singapore)
+ - value: ap-northeast-1
+ label:
+ en_US: Asia Pacific (Tokyo)
+ zh_Hans: Asia Pacific (Tokyo)
+ - value: eu-central-1
+ label:
+ en_US: Europe (Frankfurt)
+ zh_Hans: Europe (Frankfurt)
+ - value: us-gov-west-1
+ label:
+ en_US: AWS GovCloud (US-West)
+ zh_Hans: AWS GovCloud (US-West)
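+      # Bedrock is available only in a subset of AWS regions; extend this list as
+      # additional regions gain Bedrock support.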
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/__init__.py b/api/core/model_runtime/model_providers/bedrock/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
new file mode 100644
index 0000000000..c4be732f2e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
@@ -0,0 +1,10 @@
+- amazon.titan-text-express-v1
+- amazon.titan-text-lite-v1
+- anthropic.claude-instant-v1
+- anthropic.claude-v1
+- anthropic.claude-v2
+- anthropic.claude-v2:1
+- cohere.command-light-text-v14
+- cohere.command-text-v14
+- meta.llama2-13b-chat-v1
+- meta.llama2-70b-chat-v1
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml
new file mode 100644
index 0000000000..65dad02969
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml
@@ -0,0 +1,47 @@
+model: ai21.j2-mid-v1
+label:
+ en_US: J2 Mid V1
+model_type: llm
+model_properties:
+ mode: completion
+ context_size: 8191
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: maxTokens
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 2048
+ - name: count_penalty
+ label:
+ en_US: Count Penalty
+ required: false
+ type: float
+ default: 0
+ min: 0
+ max: 1
+ - name: presence_penalty
+ label:
+ en_US: Presence Penalty
+ required: false
+ type: float
+ default: 0
+ min: 0
+ max: 5
+ - name: frequency_penalty
+ label:
+ en_US: Frequency Penalty
+ required: false
+ type: float
+ default: 0
+ min: 0
+ max: 500
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml
new file mode 100644
index 0000000000..b72f8064bd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml
@@ -0,0 +1,47 @@
+model: ai21.j2-ultra-v1
+label:
+ en_US: J2 Ultra V1
+model_type: llm
+model_properties:
+ mode: completion
+ context_size: 8191
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: maxTokens
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 2048
+ - name: count_penalty
+ label:
+ en_US: Count Penalty
+ required: false
+ type: float
+ default: 0
+ min: 0
+ max: 1
+ - name: presence_penalty
+ label:
+ en_US: Presence Penalty
+ required: false
+ type: float
+ default: 0
+ min: 0
+ max: 5
+ - name: frequency_penalty
+ label:
+ en_US: Frequency Penalty
+ required: false
+ type: float
+ default: 0
+ min: 0
+ max: 500
+pricing:
+ input: '0.00'
+ output: '0.00'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml
new file mode 100644
index 0000000000..64f992b913
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml
@@ -0,0 +1,25 @@
+model: amazon.titan-text-express-v1
+label:
+ en_US: Titan Text G1 - Express
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: maxTokenCount
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 8000
+pricing:
+ input: '0.0008'
+ output: '0.0016'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml
new file mode 100644
index 0000000000..69b298b571
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml
@@ -0,0 +1,25 @@
+model: amazon.titan-text-lite-v1
+label:
+ en_US: Titan Text G1 - Lite
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 4096
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: maxTokenCount
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 2048
+pricing:
+ input: '0.0003'
+ output: '0.0004'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml
new file mode 100644
index 0000000000..94b741f50d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml
@@ -0,0 +1,35 @@
+model: anthropic.claude-instant-v1
+label:
+ en_US: Claude Instant V1
+model_type: llm
+model_properties:
+ mode: chat
+ context_size: 100000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: topK
+ label:
+ zh_Hans: 取样数量
+ en_US: Top K
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ default: 250
+ min: 0
+ max: 500
+ - name: max_tokens_to_sample
+ use_template: max_tokens
+ required: true
+ default: 4096
+ min: 1
+ max: 4096
+pricing:
+ input: '0.0008'
+ output: '0.0024'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml
new file mode 100644
index 0000000000..1c85923335
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml
@@ -0,0 +1,35 @@
+model: anthropic.claude-v1
+label:
+ en_US: Claude V1
+model_type: llm
+model_properties:
+ mode: chat
+ context_size: 100000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: topK
+ label:
+ zh_Hans: 取样数量
+ en_US: Top K
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ default: 250
+ min: 0
+ max: 500
+ - name: max_tokens_to_sample
+ use_template: max_tokens
+ required: true
+ default: 4096
+ min: 1
+ max: 4096
+pricing:
+ input: '0.008'
+ output: '0.024'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml
new file mode 100644
index 0000000000..d12e7fce90
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml
@@ -0,0 +1,35 @@
+model: anthropic.claude-v2
+label:
+ en_US: Claude V2
+model_type: llm
+model_properties:
+ mode: chat
+ context_size: 100000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: topK
+ label:
+ zh_Hans: 取样数量
+ en_US: Top K
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ default: 250
+ min: 0
+ max: 500
+ - name: max_tokens_to_sample
+ use_template: max_tokens
+ required: true
+ default: 4096
+ min: 1
+ max: 4096
+pricing:
+ input: '0.008'
+ output: '0.024'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml
new file mode 100644
index 0000000000..c5502daf4a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml
@@ -0,0 +1,35 @@
+model: anthropic.claude-v2:1
+label:
+ en_US: Claude V2.1
+model_type: llm
+model_properties:
+ mode: chat
+ context_size: 200000
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: topP
+ use_template: top_p
+ - name: topK
+ label:
+ zh_Hans: 取样数量
+ en_US: Top K
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ default: 250
+ min: 0
+ max: 500
+ - name: max_tokens_to_sample
+ use_template: max_tokens
+ required: true
+ default: 4096
+ min: 1
+ max: 4096
+pricing:
+ input: '0.008'
+ output: '0.024'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml
new file mode 100644
index 0000000000..1fad910058
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml
@@ -0,0 +1,35 @@
+model: cohere.command-light-text-v14
+label:
+ en_US: Command Light Text V14
+model_type: llm
+model_properties:
+ mode: completion
+ context_size: 4096
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: p
+ use_template: top_p
+ - name: k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ min: 0
+ max: 500
+ default: 0
+ - name: max_tokens_to_sample
+ use_template: max_tokens
+ required: true
+ default: 4096
+ min: 1
+ max: 4096
+pricing:
+ input: '0.0003'
+ output: '0.0006'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml
new file mode 100644
index 0000000000..ed775afd7a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml
@@ -0,0 +1,32 @@
+model: cohere.command-text-v14
+label:
+ en_US: Command Text V14
+model_type: llm
+model_properties:
+ mode: completion
+ context_size: 4096
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false
+ - name: max_tokens_to_sample
+ use_template: max_tokens
+ required: true
+ default: 4096
+ min: 1
+ max: 4096
+pricing:
+ input: '0.0015'
+ output: '0.0020'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
new file mode 100644
index 0000000000..269e814d99
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -0,0 +1,486 @@
+import json
+import logging
+from typing import Generator, List, Optional, Union
+
+import boto3
+from botocore.config import Config
+from botocore.exceptions import (ClientError, EndpointConnectionError, NoRegionError, ServiceNotInRegionError,
+                                 UnknownServiceError)
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
+ PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
+ InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+class BedrockLargeLanguageModel(LargeLanguageModel):
+
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ # invoke model
+ return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+ def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param messages: prompt messages or message string
+ :param tools: tools for tool calling
+        :return: number of tokens
+ """
+ prefix = model.split('.')[0]
+
+ if isinstance(messages, str):
+ prompt = messages
+ else:
+ prompt = self._convert_messages_to_prompt(messages, prefix)
+
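+        # Bedrock exposes no tokenizer API, so the GPT-2 tokenizer is used here as an
+        # approximation of the prompt token count.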
+ return self._get_num_tokens_by_gpt2(prompt)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+
+ try:
+ ping_message = UserPromptMessage(content="ping")
+ self._generate(model=model,
+ credentials=credentials,
+ prompt_messages=[ping_message],
+ model_parameters={},
+ stream=False)
+
+ except ClientError as ex:
+ error_code = ex.response['Error']['Code']
+ full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+
+ raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg)))
+
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
+ """
+ Convert a single message to a string.
+
+ :param message: PromptMessage to convert.
+ :return: String representation of the message.
+ """
+
+ if model_prefix == "anthropic":
+ human_prompt_prefix = "\n\nHuman:"
+ human_prompt_postfix = ""
+ ai_prompt = "\n\nAssistant:"
+
+ elif model_prefix == "meta":
+ human_prompt_prefix = "\n[INST]"
+            human_prompt_postfix = "[/INST]\n"
+ ai_prompt = ""
+
+ elif model_prefix == "amazon":
+ human_prompt_prefix = "\n\nUser:"
+ human_prompt_postfix = ""
+ ai_prompt = "\n\nBot:"
+
+ else:
+ human_prompt_prefix = ""
+ human_prompt_postfix = ""
+ ai_prompt = ""
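+        # e.g. with the "anthropic" prefix, a user turn is rendered as "\n\nHuman: <content> "
+        # and the assistant turn is introduced with "\n\nAssistant:"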
+
+ content = message.content
+
+ if isinstance(message, UserPromptMessage):
+ message_text = f"{human_prompt_prefix} {content} {human_prompt_postfix}"
+ elif isinstance(message, AssistantPromptMessage):
+ message_text = f"{ai_prompt} {content}"
+ elif isinstance(message, SystemPromptMessage):
+ message_text = content
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ return message_text
+
+ def _convert_messages_to_prompt(self, messages: List[PromptMessage], model_prefix: str) -> str:
+ """
+ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
+
+ :param messages: List of PromptMessage to combine.
+ :return: Combined string with necessary human_prompt and ai_prompt tags.
+ """
+ if not messages:
+ return ''
+
+ messages = messages.copy() # don't mutate the original list
+ if not isinstance(messages[-1], AssistantPromptMessage):
+ messages.append(AssistantPromptMessage(content=""))
+
+ text = "".join(
+ self._convert_one_message_to_text(message, model_prefix)
+ for message in messages
+ )
+
+ # trim off the trailing ' ' that might come from the "Assistant: "
+ return text.rstrip()
+
+ def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, stream: bool = True):
+ """
+ Create payload for bedrock api call depending on model provider
+ """
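+        # each model family expects a different JSON body, roughly:
+        #   amazon:    {"inputText": ..., "textGenerationConfig": {...}}
+        #   ai21:      {"prompt": ..., "temperature": ..., "topP": ..., "maxTokens": ...}
+        #   anthropic: {"prompt": ..., "max_tokens_to_sample": ..., "stop_sequences": [...]}
+        #   cohere:    {"prompt": ..., "stream": ...}
+        #   meta:      {"prompt": ..., "max_gen_len": ...}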
+ payload = dict()
+
+ if model_prefix == "amazon":
+ payload["textGenerationConfig"] = { **model_parameters }
+ payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else [])
+
+ payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+
+ elif model_prefix == "ai21":
+ payload["temperature"] = model_parameters.get("temperature")
+ payload["topP"] = model_parameters.get("topP")
+ payload["maxTokens"] = model_parameters.get("maxTokens")
+ payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+
+            # jurassic models only support a single stop sequence, passed as a one-element list
+            if stop:
+                payload["stopSequences"] = stop[:1]
+
+            # AI21's API expects the penalty parameters as objects with a "scale" field
+            if model_parameters.get("presence_penalty"):
+                payload["presencePenalty"] = {"scale": model_parameters.get("presence_penalty")}
+            if model_parameters.get("frequency_penalty"):
+                payload["frequencyPenalty"] = {"scale": model_parameters.get("frequency_penalty")}
+            if model_parameters.get("count_penalty"):
+                payload["countPenalty"] = {"scale": model_parameters.get("count_penalty")}
+
+ elif model_prefix == "anthropic":
+ payload = { **model_parameters }
+ payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+ payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
+
+ elif model_prefix == "cohere":
+ payload = { **model_parameters }
+ payload["prompt"] = prompt_messages[0].content
+ payload["stream"] = stream
+
+ elif model_prefix == "meta":
+ payload = { **model_parameters }
+ payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+
+ else:
+ raise ValueError(f"Got unknown model prefix {model_prefix}")
+
+ return payload
+
+ def _generate(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ stop: Optional[List[str]] = None, stream: bool = True,
+ user: Optional[str] = None) -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: credentials kwargs
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ client_config = Config(
+ region_name=credentials["aws_region"]
+ )
+
+ runtime_client = boto3.client(
+ service_name='bedrock-runtime',
+ config=client_config,
+ aws_access_key_id=credentials["aws_access_key_id"],
+ aws_secret_access_key=credentials["aws_secret_access_key"]
+ )
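+        # note: only a static access key / secret key pair is supported here; temporary
+        # credentials with a session token are not passed through to the client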
+
+ model_prefix = model.split('.')[0]
+ payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
+
+ # need workaround for ai21 models which doesn't support streaming
+ if stream and model_prefix != "ai21":
+ invoke = runtime_client.invoke_model_with_response_stream
+ else:
+ invoke = runtime_client.invoke_model
+
+ try:
+ response = invoke(
+ body=json.dumps(payload),
+ modelId=model,
+ )
+ except ClientError as ex:
+ error_code = ex.response['Error']['Code']
+ full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+ raise self._map_client_to_invoke_error(error_code, full_error_msg)
+
+ except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+ raise InvokeConnectionError(str(ex))
+
+ except UnknownServiceError as ex:
+ raise InvokeServerUnavailableError(str(ex))
+
+ except Exception as ex:
+ raise InvokeError(str(ex))
+
+ if stream:
+ return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+ return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+ def _handle_generate_response(self, model: str, credentials: dict, response: dict,
+ prompt_messages: list[PromptMessage]) -> LLMResult:
+ """
+ Handle llm response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm response
+ """
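+        # invoke_model returns the body as a botocore StreamingBody, so it is read and
+        # decoded before parsing the JSON document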
+ response_body = json.loads(response.get('body').read().decode('utf-8'))
+
+ finish_reason = response_body.get("error")
+
+ if finish_reason is not None:
+ raise InvokeError(finish_reason)
+
+ # get output text and calculate num tokens based on model / provider
+ model_prefix = model.split('.')[0]
+
+ if model_prefix == "amazon":
+ output = response_body.get("results")[0].get("outputText").strip('\n')
+ prompt_tokens = response_body.get("inputTextTokenCount")
+ completion_tokens = response_body.get("results")[0].get("tokenCount")
+
+ elif model_prefix == "ai21":
+ output = response_body.get('completions')[0].get('data').get('text')
+ prompt_tokens = len(response_body.get("prompt").get("tokens"))
+ completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
+
+ elif model_prefix == "anthropic":
+ output = response_body.get("completion")
+ prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+ completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
+
+ elif model_prefix == "cohere":
+ output = response_body.get("generations")[0].get("text")
+ prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+ completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
+
+ elif model_prefix == "meta":
+ output = response_body.get("generation").strip('\n')
+ prompt_tokens = response_body.get("prompt_token_count")
+ completion_tokens = response_body.get("generation_token_count")
+
+ else:
+ raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
+
+ # construct assistant message from output
+ assistant_prompt_message = AssistantPromptMessage(
+ content=output
+ )
+
+ # calculate usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ # construct response
+ result = LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ message=assistant_prompt_message,
+ usage=usage,
+ )
+
+ return result
+
+ def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict,
+ prompt_messages: list[PromptMessage]) -> Generator:
+ """
+ Handle llm stream response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm response chunk generator result
+ """
+ model_prefix = model.split('.')[0]
+ if model_prefix == "ai21":
+ response_body = json.loads(response.get('body').read().decode('utf-8'))
+
+ content = response_body.get('completions')[0].get('data').get('text')
+ finish_reason = response_body.get('completions')[0].get('finish_reason')
+
+ prompt_tokens = len(response_body.get("prompt").get("tokens"))
+ completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=content),
+ finish_reason=finish_reason,
+ usage=usage
+ )
+ )
+ return
+
+ stream = response.get('body')
+ if not stream:
+ raise InvokeError('No response body')
+
+ index = -1
+ for event in stream:
+ chunk = event.get('chunk')
+
+ if not chunk:
+ exception_name = next(iter(event))
+ full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
+
+ raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
+
+ payload = json.loads(chunk.get('bytes').decode())
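+            # each chunk is a JSON document; the final chunk additionally carries
+            # "amazon-bedrock-invocationMetrics" with the token counts used below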
+
+ model_prefix = model.split('.')[0]
+ if model_prefix == "amazon":
+ content_delta = payload.get("outputText").strip('\n')
+                finish_reason = payload.get("completionReason")
+
+ elif model_prefix == "anthropic":
+                content_delta = payload.get("completion")
+ finish_reason = payload.get("stop_reason")
+
+ elif model_prefix == "cohere":
+ content_delta = payload.get("text")
+ finish_reason = payload.get("finish_reason")
+
+ elif model_prefix == "meta":
+ content_delta = payload.get("generation").strip('\n')
+ finish_reason = payload.get("stop_reason")
+
+ else:
+ raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
+
+ index += 1
+
+ assistant_prompt_message = AssistantPromptMessage(
+ content = content_delta if content_delta else '',
+ )
+
+ if not finish_reason:
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=assistant_prompt_message
+ )
+ )
+
+ else:
+ # get num tokens from metrics in last chunk
+ prompt_tokens = payload["amazon-bedrock-invocationMetrics"]["inputTokenCount"]
+ completion_tokens = payload["amazon-bedrock-invocationMetrics"]["outputTokenCount"]
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=assistant_prompt_message,
+ finish_reason=finish_reason,
+ usage=usage
+ )
+ )
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+ """
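+        # the lists are left empty because Bedrock errors surface as botocore ClientError
+        # objects carrying an error code, which is mapped in _map_client_to_invoke_error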
+ return {
+ InvokeConnectionError: [],
+ InvokeServerUnavailableError: [],
+ InvokeRateLimitError: [],
+ InvokeAuthorizationError: [],
+ InvokeBadRequestError: []
+ }
+
+ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
+ """
+ Map client error to invoke error
+
+ :param error_code: error code
+ :param error_msg: error message
+ :return: invoke error
+ """
+
+ if error_code == "AccessDeniedException":
+ return InvokeAuthorizationError(error_msg)
+ elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+ return InvokeBadRequestError(error_msg)
+ elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+ return InvokeRateLimitError(error_msg)
+ elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
+ return InvokeServerUnavailableError(error_msg)
+ elif error_code == "ModelStreamErrorException":
+ return InvokeConnectionError(error_msg)
+
+ return InvokeError(error_msg)
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml
new file mode 100644
index 0000000000..a8d3704c15
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml
@@ -0,0 +1,23 @@
+model: meta.llama2-13b-chat-v1
+label:
+ en_US: Llama 2 Chat 13B
+model_type: llm
+model_properties:
+ mode: chat
+ context_size: 4096
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: max_gen_len
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 2048
+pricing:
+ input: '0.00075'
+ output: '0.00100'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml
new file mode 100644
index 0000000000..77525e630b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml
@@ -0,0 +1,23 @@
+model: meta.llama2-70b-chat-v1
+label:
+ en_US: Llama 2 Chat 70B
+model_type: llm
+model_properties:
+ mode: chat
+ context_size: 4096
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: max_gen_len
+ use_template: max_tokens
+ required: true
+ default: 2048
+ min: 1
+ max: 2048
+pricing:
+ input: '0.00195'
+ output: '0.00256'
+ unit: '0.001'
+ currency: USD
diff --git a/api/tests/integration_tests/model_runtime/bedrock/__init__.py b/api/tests/integration_tests/model_runtime/bedrock/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py
new file mode 100644
index 0000000000..c395a08fe2
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py
@@ -0,0 +1,117 @@
+import os
+from typing import Generator
+
+import pytest
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel
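+
+# these tests call the live Bedrock API and expect AWS_REGION, AWS_ACCESS_KEY_ID and
+# AWS_SECRET_ACCESS_KEY to be set in the environment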
+
+def test_validate_credentials():
+ model = BedrockLargeLanguageModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model='meta.llama2-13b-chat-v1',
+ credentials={
+                'aws_region': 'us-east-1',
+                'aws_access_key_id': 'invalid_key',
+                'aws_secret_access_key': 'invalid_secret'
+ }
+ )
+
+ model.validate_credentials(
+ model='meta.llama2-13b-chat-v1',
+ credentials={
+ "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
+ "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+ }
+ )
+
+def test_invoke_model():
+ model = BedrockLargeLanguageModel()
+
+ response = model.invoke(
+ model='meta.llama2-13b-chat-v1',
+ credentials={
+ "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
+ "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ],
+ model_parameters={
+ 'temperature': 0.0,
+ 'top_p': 1.0,
+ 'max_tokens_to_sample': 10
+ },
+ stop=['How'],
+ stream=False,
+ user="abc-123"
+ )
+
+ assert isinstance(response, LLMResult)
+ assert len(response.message.content) > 0
+
+def test_invoke_stream_model():
+ model = BedrockLargeLanguageModel()
+
+ response = model.invoke(
+ model='meta.llama2-13b-chat-v1',
+ credentials={
+ "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
+ "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ],
+ model_parameters={
+ 'temperature': 0.0,
+ 'max_tokens_to_sample': 100
+ },
+ stream=True,
+ user="abc-123"
+ )
+
+ assert isinstance(response, Generator)
+
+ for chunk in response:
+ print(chunk)
+ assert isinstance(chunk, LLMResultChunk)
+ assert isinstance(chunk.delta, LLMResultChunkDelta)
+ assert isinstance(chunk.delta.message, AssistantPromptMessage)
+ assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+ model = BedrockLargeLanguageModel()
+
+ num_tokens = model.get_num_tokens(
+ model='meta.llama2-13b-chat-v1',
+ credentials = {
+ "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
+ "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+ },
+ messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ]
+ )
+
+ assert num_tokens == 18
diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py
new file mode 100644
index 0000000000..ce3d61c0f7
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py
@@ -0,0 +1,21 @@
+import os
+
+import pytest
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider
+
+def test_validate_provider_credentials():
+ provider = BedrockProvider()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ provider.validate_provider_credentials(
+ credentials={}
+ )
+
+ provider.validate_provider_credentials(
+ credentials={
+ "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
+ "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+ }
+ )