[seanguo] modify bedrock Claude3 invoke method to converse API (#5768)

Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
Author: longzhihun
Date: 2024-07-01 04:36:13 +08:00 (committed by GitHub)
Parent: a27462d58b
Commit: fdfbbde10d
3 changed files with 101 additions and 151 deletions

api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -1,22 +1,14 @@
# standard import
import base64
import json
import logging
import mimetypes
import time
from collections.abc import Generator
from typing import Optional, Union, cast
# 3rd import
import boto3
import requests
from anthropic import AnthropicBedrock, Stream
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
)
from botocore.config import Config
from botocore.exceptions import (
ClientError,
@@ -27,7 +19,8 @@ from botocore.exceptions import (
)
from cohere import ChatMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
# local import
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
@@ -38,7 +31,6 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -73,8 +65,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke anthropic models via anthropic official SDK
# TODO: consolidate different invocation methods for models based on base model capabilities
# invoke anthropic models via boto3 client
if "anthropic" in model:
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
# invoke Cohere models via boto3 client
@@ -171,48 +163,34 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
-        # use Anthropic official SDK references
-        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
-        # - https://github.com/anthropics/anthropic-sdk-python
-        client = AnthropicBedrock(
-            aws_access_key=credentials.get("aws_access_key_id"),
-            aws_secret_key=credentials.get("aws_secret_access_key"),
-            aws_region=credentials["aws_region"],
-        )
+        bedrock_client = boto3.client(service_name='bedrock-runtime',
+                                      aws_access_key_id=credentials.get("aws_access_key_id"),
+                                      aws_secret_access_key=credentials.get("aws_secret_access_key"),
+                                      region_name=credentials["aws_region"])

-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop

-        # Notice: If you request the current version of the SDK to the bedrock server,
-        # you will get the following error message and you need to wait for the service or SDK to be updated.
-        # Response: Error code: 400
-        # {'message': 'Malformed input request: #: subject must not be valid against schema
-        # {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
-        # TODO: Open in the future when the interface is properly supported
-        # if user:
-        #     ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
-        #     extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)

-        system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
-        if system:
-            extra_model_kwargs['system'] = system

-        response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
+        system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
+        inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

         if stream:
-            return self._handle_claude_stream_response(model, credentials, response, prompt_messages)
+            response = bedrock_client.converse_stream(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
+            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = bedrock_client.converse(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
+            return self._handle_converse_response(model, credentials, response, prompt_messages)

-        return self._handle_claude_response(model, credentials, response, prompt_messages)
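
For orientation, here is a minimal standalone sketch of the Converse call the new code path issues; the model ID, region, and prompt are illustrative placeholders, and credentials are assumed to come from the environment:

import boto3

# Build a Bedrock runtime client; region and model ID below are examples only.
bedrock = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')

response = bedrock.converse(
    modelId='anthropic.claude-3-sonnet-20240229-v1:0',
    messages=[{'role': 'user', 'content': [{'text': 'Hello, Claude'}]}],
    system=[{'text': 'You are a concise assistant.'}],
    inferenceConfig={'maxTokens': 1024, 'temperature': 0.5, 'topP': 0.9},
    additionalModelRequestFields={'top_k': 200},
)
print(response['output']['message']['content'][0]['text'])
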
def _handle_claude_response(self, model: str, credentials: dict, response: Message,
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm chat response
@@ -223,17 +201,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: full response chunk generator result
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.content[0].text
content=response['output']['message']['content'][0]['text']
)
# calculate num tokens
if response.usage:
if response['usage']:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
prompt_tokens = response['usage']['inputTokens']
completion_tokens = response['usage']['outputTokens']
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -242,17 +219,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
response = LLMResult(
model=response.model,
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
usage=usage,
)
return result
return response
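
The handler above indexes into the plain dict that converse() returns, roughly this shape (values illustrative):

# Approximate shape of a converse() response, as consumed by
# _handle_converse_response above (values illustrative).
response = {
    'output': {
        'message': {
            'role': 'assistant',
            'content': [{'text': 'Hello! How can I help you today?'}],
        }
    },
    'stopReason': 'end_turn',
    'usage': {'inputTokens': 14, 'outputTokens': 9, 'totalTokens': 23},
}
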
def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage], ) -> Generator:
"""
Handle llm chat stream response
@@ -272,14 +247,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
finish_reason = None
index = 0
for chunk in response:
if isinstance(chunk, MessageStartEvent):
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
for chunk in response['stream']:
if 'messageStart' in chunk:
return_model = model
elif 'messageStop' in chunk:
finish_reason = chunk['messageStop']['stopReason']
elif 'metadata' in chunk:
input_tokens = chunk['metadata']['usage']['inputTokens']
output_tokens = chunk['metadata']['usage']['outputTokens']
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=return_model,
@@ -293,13 +268,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
usage=usage
)
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
elif 'contentBlockDelta' in chunk:
chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else '',
)
index = chunk.index
index = chunk['contentBlockDelta']['contentBlockIndex']
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@@ -311,56 +286,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise InvokeError(str(ex))
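
converse_stream() wraps its events under the 'stream' key, and each event is a single-key dict, which is why the new handler dispatches on key membership rather than on event types as the old Anthropic SDK code did. A minimal consumer sketch, reusing the bedrock client from the earlier example (model ID illustrative):

# Iterate converse_stream events; each event dict carries exactly one key.
stream_response = bedrock.converse_stream(
    modelId='anthropic.claude-3-sonnet-20240229-v1:0',
    messages=[{'role': 'user', 'content': [{'text': 'Hi'}]}],
)
for event in stream_response['stream']:
    if 'messageStart' in event:
        role = event['messageStart']['role']  # 'assistant'
    elif 'contentBlockDelta' in event:
        print(event['contentBlockDelta']['delta']['text'], end='')
    elif 'messageStop' in event:
        finish = event['messageStop']['stopReason']  # e.g. 'end_turn'
    elif 'metadata' in event:
        usage = event['metadata']['usage']  # inputTokens / outputTokens counts
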
-    def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
-        """
-        Calculate response usage
+    def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:
+        inference_config = {}
+        additional_model_fields = {}
+        if 'max_tokens' in model_parameters:
+            inference_config['maxTokens'] = model_parameters['max_tokens']

-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_tokens: prompt tokens
-        :param completion_tokens: completion tokens
-        :return: usage
-        """
-        # get prompt price info
-        prompt_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.INPUT,
-            tokens=prompt_tokens,
-        )
+        if 'temperature' in model_parameters:
+            inference_config['temperature'] = model_parameters['temperature']

-        # get completion price info
-        completion_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.OUTPUT,
-            tokens=completion_tokens
-        )
+        if 'top_p' in model_parameters:
+            inference_config['topP'] = model_parameters['top_p']

-        # transform usage
-        usage = LLMUsage(
-            prompt_tokens=prompt_tokens,
-            prompt_unit_price=prompt_price_info.unit_price,
-            prompt_price_unit=prompt_price_info.unit,
-            prompt_price=prompt_price_info.total_amount,
-            completion_tokens=completion_tokens,
-            completion_unit_price=completion_price_info.unit_price,
-            completion_price_unit=completion_price_info.unit,
-            completion_price=completion_price_info.total_amount,
-            total_tokens=prompt_tokens + completion_tokens,
-            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
-            currency=prompt_price_info.currency,
-            latency=time.perf_counter() - self.started_at
-        )
+        if stop:
+            inference_config['stopSequences'] = stop

-        return usage
+        if 'top_k' in model_parameters:
+            additional_model_fields['top_k'] = model_parameters['top_k']

-    def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+        return inference_config, additional_model_fields

+    def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = ""
system = []
first_loop = True
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
@@ -375,25 +326,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts
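
Taken together, the two new converters reshape Dify's prompt entities and snake_case parameters into the pieces a Converse request needs. Inside the class, a hypothetical round trip might look like this (all values illustrative):

# Hypothetical round trip through the two converters above.
system, messages = self._convert_converse_prompt_messages([
    SystemPromptMessage(content='You are terse.'),
    UserPromptMessage(content='Hi'),
])
# system   -> [{'text': 'You are terse.'}]
# messages -> [{'role': 'user', 'content': [{'text': 'Hi'}]}]

inference_config, additional_model_fields = self._convert_converse_api_model_parameters(
    {'max_tokens': 1024, 'temperature': 0.7, 'top_p': 0.9, 'top_k': 200},
    stop=['###'],
)
# inference_config        -> {'maxTokens': 1024, 'temperature': 0.7,
#                             'topP': 0.9, 'stopSequences': ['###']}
# additional_model_fields -> {'top_k': 200}
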
def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
message_dict = {"role": "user", "content": [{'text': message.content}]}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
@@ -404,24 +354,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
sub_message_dict = {
"type": "image",
"image": {
"format": mime_type.replace('image/', ''),
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
"bytes": image_content
}
}
}
sub_messages.append(sub_message_dict)
@@ -429,10 +379,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
message_dict = {"role": "assistant", "content": [{'text': message.content}]}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
message_dict = [{'text': message.content}]
else:
raise ValueError(f"Got unknown type {message}")
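
Note the two format changes relative to the Anthropic Messages API: plain strings become {'text': ...} content blocks, and images are passed as raw bytes with a short format tag instead of base64 data plus a media type. A sketch of a resulting user turn (bytes illustrative):

# Converse-style user turn as produced by _convert_prompt_message_to_dict.
user_turn = {
    'role': 'user',
    'content': [
        {'text': 'What is in this image?'},
        {'image': {'format': 'png', 'source': {'bytes': b'<raw png bytes>'}}},
    ],
}
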

api/poetry.lock (generated)

@@ -534,41 +534,41 @@ files = [
[[package]]
name = "boto3"
version = "1.28.17"
version = "1.34.136"
description = "The AWS SDK for Python"
optional = false
python-versions = ">= 3.7"
python-versions = ">=3.8"
files = [
{file = "boto3-1.28.17-py3-none-any.whl", hash = "sha256:bca0526f819e0f19c0f1e6eba3e2d1d6b6a92a45129f98c0d716e5aab6d9444b"},
{file = "boto3-1.28.17.tar.gz", hash = "sha256:90f7cfb5e1821af95b1fc084bc50e6c47fa3edc99f32de1a2591faa0c546bea7"},
{file = "boto3-1.34.136-py3-none-any.whl", hash = "sha256:d41037e2c680ab8d6c61a0a4ee6bf1fdd9e857f43996672830a95d62d6f6fa79"},
{file = "boto3-1.34.136.tar.gz", hash = "sha256:0314e6598f59ee0f34eb4e6d1a0f69fa65c146d2b88a6e837a527a9956ec2731"},
]
[package.dependencies]
botocore = ">=1.31.17,<1.32.0"
botocore = ">=1.34.136,<1.35.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.6.0,<0.7.0"
s3transfer = ">=0.10.0,<0.11.0"
[package.extras]
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.31.85"
version = "1.34.136"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">= 3.7"
python-versions = ">=3.8"
files = [
{file = "botocore-1.31.85-py3-none-any.whl", hash = "sha256:b8f35d65f2b45af50c36fc25cc1844d6bd61d38d2148b2ef133b8f10e198555d"},
{file = "botocore-1.31.85.tar.gz", hash = "sha256:ce58e688222df73ec5691f934be1a2122a52c9d11d3037b586b3fff16ed6d25f"},
{file = "botocore-1.34.136-py3-none-any.whl", hash = "sha256:c63fe9032091fb9e9477706a3ebfa4d0c109b807907051d892ed574f9b573e61"},
{file = "botocore-1.34.136.tar.gz", hash = "sha256:7f7135178692b39143c8f152a618d2a3b71065a317569a7102d2306d4946f42f"},
]
[package.dependencies]
jmespath = ">=0.7.1,<2.0.0"
python-dateutil = ">=2.1,<3.0.0"
urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}
urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}
[package.extras]
crt = ["awscrt (==0.19.12)"]
crt = ["awscrt (==0.20.11)"]
[[package]]
name = "bottleneck"
@@ -7032,20 +7032,20 @@ files = [
[[package]]
name = "s3transfer"
version = "0.6.2"
version = "0.10.2"
description = "An Amazon S3 Transfer Manager"
optional = false
python-versions = ">= 3.7"
python-versions = ">=3.8"
files = [
{file = "s3transfer-0.6.2-py3-none-any.whl", hash = "sha256:b014be3a8a2aab98cfe1abc7229cc5a9a0cf05eb9c1f2b86b230fd8df3f78084"},
{file = "s3transfer-0.6.2.tar.gz", hash = "sha256:cab66d3380cca3e70939ef2255d01cd8aece6a4907a9528740f668c4b0611861"},
{file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"},
{file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"},
]
[package.dependencies]
botocore = ">=1.12.36,<2.0a.0"
botocore = ">=1.33.2,<2.0a.0"
[package.extras]
crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
[[package]]
name = "safetensors"
@@ -9095,4 +9095,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "d40bed69caecf3a2bcd5ec054288d7cb36a9a231fff210d4f1a42745dd3bf604"
content-hash = "90f0e77567fbe5100d15bf2bc9472007aafc53c2fd594b6a90dd8455dea58582"

api/pyproject.toml

@@ -107,7 +107,7 @@ authlib = "1.3.1"
azure-identity = "1.16.1"
azure-storage-blob = "12.13.0"
beautifulsoup4 = "4.12.2"
boto3 = "1.28.17"
boto3 = "1.34.136"
bs4 = "~0.0.1"
cachetools = "~5.3.0"
celery = "~5.3.6"
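
The dependency bump matters here: the converse and converse_stream operations only exist in recent botocore service definitions, hence the move from boto3 1.28.17 to 1.34.136. A quick sanity-check sketch (region value illustrative):

# Verify the installed boto3 actually exposes the Converse API.
import boto3

client = boto3.client('bedrock-runtime', region_name='us-east-1')
if not (hasattr(client, 'converse') and hasattr(client, 'converse_stream')):
    raise RuntimeError(
        f'boto3 {boto3.__version__} predates the Bedrock Converse API; '
        'upgrade to the 1.34.x line pinned above'
    )
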