[seanguo] modify bedrock Claude3 invoke method to converse API (#5768)

Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
longzhihun 2024-07-01 04:36:13 +08:00 committed by GitHub
parent a27462d58b
commit fdfbbde10d
3 changed files with 101 additions and 151 deletions
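This commit replaces the Anthropic-SDK invocation path for Claude 3 on Bedrock with boto3's unified Converse API. For orientation, here is a minimal sketch of the call shape the new code adopts; the region, model ID, and prompt are placeholder values (in the diff they come from `credentials` and the caller), not taken from this commit:

    import boto3

    # placeholder region and model ID -- the diff reads these from `credentials`
    bedrock_client = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')

    response = bedrock_client.converse(
        modelId='anthropic.claude-3-sonnet-20240229-v1:0',
        messages=[{'role': 'user', 'content': [{'text': 'Hello!'}]}],
        system=[{'text': 'You are a concise assistant.'}],
        inferenceConfig={'maxTokens': 1024, 'temperature': 0.5},
    )
    print(response['output']['message']['content'][0]['text'])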


@@ -1,22 +1,14 @@
+# standard import
 import base64
 import json
 import logging
 import mimetypes
-import time
 from collections.abc import Generator
 from typing import Optional, Union, cast

+# 3rd import
 import boto3
 import requests
-from anthropic import AnthropicBedrock, Stream
-from anthropic.types import (
-    ContentBlockDeltaEvent,
-    Message,
-    MessageDeltaEvent,
-    MessageStartEvent,
-    MessageStopEvent,
-    MessageStreamEvent,
-)
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -27,7 +19,8 @@ from botocore.exceptions import (
 )
 from cohere import ChatMessage

+# local import
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -38,7 +31,6 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -73,8 +65,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        # invoke anthropic models via anthropic official SDK
+        # TODO: consolidate different invocation methods for models based on base model capabilities
+        # invoke anthropic models via boto3 client
         if "anthropic" in model:
             return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
         # invoke Cohere models via boto3 client
@@ -171,48 +163,34 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
-        # use Anthropic official SDK references
-        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
-        # - https://github.com/anthropics/anthropic-sdk-python
-        client = AnthropicBedrock(
-            aws_access_key=credentials.get("aws_access_key_id"),
-            aws_secret_key=credentials.get("aws_secret_access_key"),
-            aws_region=credentials["aws_region"],
-        )
-
-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        # Notice: If you request the current version of the SDK to the bedrock server,
-        # you will get the following error message and you need to wait for the service or SDK to be updated.
-        # Response: Error code: 400
-        # {'message': 'Malformed input request: #: subject must not be valid against schema
-        # {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
-        # TODO: Open in the future when the interface is properly supported
-        # if user:
-        #     ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
-        #     extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)
-
-        system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
-        if system:
-            extra_model_kwargs['system'] = system
-
-        response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
+        bedrock_client = boto3.client(service_name='bedrock-runtime',
+                                      aws_access_key_id=credentials.get("aws_access_key_id"),
+                                      aws_secret_access_key=credentials.get("aws_secret_access_key"),
+                                      region_name=credentials["aws_region"])
+
+        system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
+        inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

         if stream:
-            return self._handle_claude_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_claude_response(model, credentials, response, prompt_messages)
+            response = bedrock_client.converse_stream(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
+            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = bedrock_client.converse(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
+            return self._handle_converse_response(model, credentials, response, prompt_messages)

-    def _handle_claude_response(self, model: str, credentials: dict, response: Message,
-                                prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_converse_response(self, model: str, credentials: dict, response: dict,
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm chat response
@@ -223,17 +201,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: full response chunk generator result
         """
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.content[0].text
+            content=response['output']['message']['content'][0]['text']
         )

         # calculate num tokens
-        if response.usage:
+        if response['usage']:
             # transform usage
-            prompt_tokens = response.usage.input_tokens
-            completion_tokens = response.usage.output_tokens
+            prompt_tokens = response['usage']['inputTokens']
+            completion_tokens = response['usage']['outputTokens']
         else:
             # calculate num tokens
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -242,17 +219,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

-        # transform response
-        response = LLMResult(
-            model=response.model,
+        result = LLMResult(
+            model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage
+            usage=usage,
         )
+        return result

-        return response
-
-    def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
+    def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
                                        prompt_messages: list[PromptMessage], ) -> Generator:
         """
         Handle llm chat stream response
@@ -272,14 +247,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             finish_reason = None
             index = 0
-            for chunk in response:
-                if isinstance(chunk, MessageStartEvent):
-                    return_model = chunk.message.model
-                    input_tokens = chunk.message.usage.input_tokens
-                elif isinstance(chunk, MessageDeltaEvent):
-                    output_tokens = chunk.usage.output_tokens
-                    finish_reason = chunk.delta.stop_reason
-                elif isinstance(chunk, MessageStopEvent):
+            for chunk in response['stream']:
+                if 'messageStart' in chunk:
+                    return_model = model
+                elif 'messageStop' in chunk:
+                    finish_reason = chunk['messageStop']['stopReason']
+                elif 'metadata' in chunk:
+                    input_tokens = chunk['metadata']['usage']['inputTokens']
+                    output_tokens = chunk['metadata']['usage']['outputTokens']
                     usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
                     yield LLMResultChunk(
                         model=return_model,
@@ -293,13 +268,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                             usage=usage
                         )
                     )
-                elif isinstance(chunk, ContentBlockDeltaEvent):
-                    chunk_text = chunk.delta.text if chunk.delta.text else ''
+                elif 'contentBlockDelta' in chunk:
+                    chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
                     full_assistant_content += chunk_text
                     assistant_prompt_message = AssistantPromptMessage(
                         content=chunk_text if chunk_text else '',
                     )
-                    index = chunk.index
+                    index = chunk['contentBlockDelta']['contentBlockIndex']
                     yield LLMResultChunk(
                         model=model,
                         prompt_messages=prompt_messages,
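For reference, a standalone sketch of how the converse_stream event types consumed above map to output. The region, model ID, and prompt are placeholders; event shapes follow the Bedrock Runtime documentation (text-only content is assumed, so 'contentBlockDelta' always carries a 'text' delta here):

    import boto3

    bedrock_client = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')
    response = bedrock_client.converse_stream(
        modelId='anthropic.claude-3-sonnet-20240229-v1:0',
        messages=[{'role': 'user', 'content': [{'text': 'Write a haiku about rivers.'}]}],
    )

    for chunk in response['stream']:
        if 'contentBlockDelta' in chunk:
            # incremental text deltas, keyed by contentBlockIndex
            print(chunk['contentBlockDelta']['delta']['text'], end='')
        elif 'messageStop' in chunk:
            print(f"\nstop reason: {chunk['messageStop']['stopReason']}")  # e.g. 'end_turn'
        elif 'metadata' in chunk:
            usage = chunk['metadata']['usage']  # arrives once, after the message
            print(f"tokens: {usage['inputTokens']} in / {usage['outputTokens']} out")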
@@ -311,56 +286,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise InvokeError(str(ex))

-    def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
-        """
-        Calculate response usage
-
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_tokens: prompt tokens
-        :param completion_tokens: completion tokens
-        :return: usage
-        """
-        # get prompt price info
-        prompt_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.INPUT,
-            tokens=prompt_tokens,
-        )
-
-        # get completion price info
-        completion_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.OUTPUT,
-            tokens=completion_tokens
-        )
-
-        # transform usage
-        usage = LLMUsage(
-            prompt_tokens=prompt_tokens,
-            prompt_unit_price=prompt_price_info.unit_price,
-            prompt_price_unit=prompt_price_info.unit,
-            prompt_price=prompt_price_info.total_amount,
-            completion_tokens=completion_tokens,
-            completion_unit_price=completion_price_info.unit_price,
-            completion_price_unit=completion_price_info.unit,
-            completion_price=completion_price_info.total_amount,
-            total_tokens=prompt_tokens + completion_tokens,
-            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
-            currency=prompt_price_info.currency,
-            latency=time.perf_counter() - self.started_at
-        )
-
-        return usage
-
-    def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+    def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:
+        inference_config = {}
+        additional_model_fields = {}
+        if 'max_tokens' in model_parameters:
+            inference_config['maxTokens'] = model_parameters['max_tokens']
+
+        if 'temperature' in model_parameters:
+            inference_config['temperature'] = model_parameters['temperature']
+
+        if 'top_p' in model_parameters:
+            inference_config['topP'] = model_parameters['top_p']
+
+        if stop:
+            inference_config['stopSequences'] = stop
+
+        if 'top_k' in model_parameters:
+            additional_model_fields['top_k'] = model_parameters['top_k']
+
+        return inference_config, additional_model_fields
+
+    def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
         """
         Convert prompt messages to dict list and system
         """
-        system = ""
+        system = []
         first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
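A worked sketch of the parameter split the new _convert_converse_api_model_parameters helper performs: Converse-native settings land in inferenceConfig (camelCase keys), while the Anthropic-specific top_k passes through in additionalModelRequestFields. The sample values and stop sequence below are illustrative only:

    model_parameters = {'max_tokens': 512, 'temperature': 0.7, 'top_p': 0.9, 'top_k': 250}

    inference_config = {                    # native Converse API settings
        'maxTokens': model_parameters['max_tokens'],
        'temperature': model_parameters['temperature'],
        'topP': model_parameters['top_p'],
        'stopSequences': ['\n\nHuman:'],    # from the optional stop list
    }
    additional_model_fields = {'top_k': model_parameters['top_k']}  # provider-specific passthrough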
@@ -375,25 +326,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         prompt_message_dicts = []
         for message in prompt_messages:
             if not isinstance(message, SystemPromptMessage):
-                prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
+                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))

         return system, prompt_message_dicts

-    def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "user", "content": message.content}
+                message_dict = {"role": "user", "content": [{'text': message.content}]}
             else:
                 sub_messages = []
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
                         sub_message_dict = {
-                            "type": "text",
                             "text": message_content.data
                         }
                         sub_messages.append(sub_message_dict)
@@ -404,24 +354,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                         try:
                             image_content = requests.get(message_content.data).content
                             mime_type, _ = mimetypes.guess_type(message_content.data)
-                            base64_data = base64.b64encode(image_content).decode('utf-8')
                         except Exception as ex:
                             raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                     else:
                         data_split = message_content.data.split(";base64,")
                         mime_type = data_split[0].replace("data:", "")
                         base64_data = data_split[1]
+                        image_content = base64.b64decode(base64_data)

                     if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
                         raise ValueError(f"Unsupported image type {mime_type}, "
                                          f"only support image/jpeg, image/png, image/gif, and image/webp")

                     sub_message_dict = {
-                        "type": "image",
-                        "source": {
-                            "type": "base64",
-                            "media_type": mime_type,
-                            "data": base64_data
-                        }
+                        "image": {
+                            "format": mime_type.replace('image/', ''),
+                            "source": {
+                                "bytes": image_content
+                            }
+                        }
                     }
                     sub_messages.append(sub_message_dict)
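The image handling above switches from the Anthropic SDK's base64 payloads to the Converse API's raw-bytes blocks. A hypothetical helper (not part of this commit) showing the resulting block shape:

    def to_converse_image_block(image_bytes: bytes, mime_type: str) -> dict:
        # 'image/png' -> 'png'; bytes are passed as-is, boto3 handles the wire encoding
        return {
            "image": {
                "format": mime_type.replace("image/", ""),
                "source": {"bytes": image_bytes},
            }
        }

    with open("cat.png", "rb") as f:  # placeholder file
        block = to_converse_image_block(f.read(), "image/png")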
@@ -429,10 +379,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": message.content}
+            message_dict = {"role": "assistant", "content": [{'text': message.content}]}
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "system", "content": message.content}
+            message_dict = [{'text': message.content}]
         else:
             raise ValueError(f"Got unknown type {message}")

api/poetry.lock (generated)

@@ -534,41 +534,41 @@ files = [
 [[package]]
 name = "boto3"
-version = "1.28.17"
+version = "1.34.136"
 description = "The AWS SDK for Python"
 optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "boto3-1.28.17-py3-none-any.whl", hash = "sha256:bca0526f819e0f19c0f1e6eba3e2d1d6b6a92a45129f98c0d716e5aab6d9444b"},
-    {file = "boto3-1.28.17.tar.gz", hash = "sha256:90f7cfb5e1821af95b1fc084bc50e6c47fa3edc99f32de1a2591faa0c546bea7"},
+    {file = "boto3-1.34.136-py3-none-any.whl", hash = "sha256:d41037e2c680ab8d6c61a0a4ee6bf1fdd9e857f43996672830a95d62d6f6fa79"},
+    {file = "boto3-1.34.136.tar.gz", hash = "sha256:0314e6598f59ee0f34eb4e6d1a0f69fa65c146d2b88a6e837a527a9956ec2731"},
 ]

 [package.dependencies]
-botocore = ">=1.31.17,<1.32.0"
+botocore = ">=1.34.136,<1.35.0"
 jmespath = ">=0.7.1,<2.0.0"
-s3transfer = ">=0.6.0,<0.7.0"
+s3transfer = ">=0.10.0,<0.11.0"

 [package.extras]
 crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]

 [[package]]
 name = "botocore"
-version = "1.31.85"
+version = "1.34.136"
 description = "Low-level, data-driven core of boto 3."
 optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "botocore-1.31.85-py3-none-any.whl", hash = "sha256:b8f35d65f2b45af50c36fc25cc1844d6bd61d38d2148b2ef133b8f10e198555d"},
-    {file = "botocore-1.31.85.tar.gz", hash = "sha256:ce58e688222df73ec5691f934be1a2122a52c9d11d3037b586b3fff16ed6d25f"},
+    {file = "botocore-1.34.136-py3-none-any.whl", hash = "sha256:c63fe9032091fb9e9477706a3ebfa4d0c109b807907051d892ed574f9b573e61"},
+    {file = "botocore-1.34.136.tar.gz", hash = "sha256:7f7135178692b39143c8f152a618d2a3b71065a317569a7102d2306d4946f42f"},
 ]

 [package.dependencies]
 jmespath = ">=0.7.1,<2.0.0"
 python-dateutil = ">=2.1,<3.0.0"
-urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}
+urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}

 [package.extras]
-crt = ["awscrt (==0.19.12)"]
+crt = ["awscrt (==0.20.11)"]

 [[package]]
 name = "bottleneck"
@@ -7032,20 +7032,20 @@ files = [
 [[package]]
 name = "s3transfer"
-version = "0.6.2"
+version = "0.10.2"
 description = "An Amazon S3 Transfer Manager"
 optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "s3transfer-0.6.2-py3-none-any.whl", hash = "sha256:b014be3a8a2aab98cfe1abc7229cc5a9a0cf05eb9c1f2b86b230fd8df3f78084"},
-    {file = "s3transfer-0.6.2.tar.gz", hash = "sha256:cab66d3380cca3e70939ef2255d01cd8aece6a4907a9528740f668c4b0611861"},
+    {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"},
+    {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"},
 ]

 [package.dependencies]
-botocore = ">=1.12.36,<2.0a.0"
+botocore = ">=1.33.2,<2.0a.0"

 [package.extras]
-crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]

 [[package]]
 name = "safetensors"
@@ -9095,4 +9095,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "d40bed69caecf3a2bcd5ec054288d7cb36a9a231fff210d4f1a42745dd3bf604"
+content-hash = "90f0e77567fbe5100d15bf2bc9472007aafc53c2fd594b6a90dd8455dea58582"


@@ -107,7 +107,7 @@ authlib = "1.3.1"
 azure-identity = "1.16.1"
 azure-storage-blob = "12.13.0"
 beautifulsoup4 = "4.12.2"
-boto3 = "1.28.17"
+boto3 = "1.34.136"
 bs4 = "~0.0.1"
 cachetools = "~5.3.0"
 celery = "~5.3.6"
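The boto3/botocore bump above is what makes the new client methods available; Converse support landed partway through the 1.34.x line, so the exact minimum version is an assumption here. A quick sanity check against an installed environment (placeholder region):

    import boto3

    client = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')
    if not (hasattr(client, 'converse') and hasattr(client, 'converse_stream')):
        raise RuntimeError(f"boto3 {boto3.__version__} predates the Bedrock Converse API")
    print(f"boto3 {boto3.__version__}: Converse API available")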