diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 5745721ae8..b274cec35f 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -74,12 +74,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # invoke claude 3 models via anthropic official SDK if "anthropic.claude-3" in model: - return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream) + return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream, user) # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) def _invoke_claude3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True) -> Union[LLMResult, Generator]: + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke Claude3 large language model @@ -100,22 +100,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel): aws_region=credentials["aws_region"], ) + extra_model_kwargs = {} + if stop: + extra_model_kwargs['stop_sequences'] = stop + + # Notice: If you request the current version of the SDK to the bedrock server, + # you will get the following error message and you need to wait for the service or SDK to be updated. + # Response: Error code: 400 + # {'message': 'Malformed input request: #: subject must not be valid against schema + # {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'} + # TODO: Open in the future when the interface is properly supported + # if user: + # ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465 + # extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user) + system, prompt_message_dicts = self._convert_claude3_prompt_messages(prompt_messages) + if system: + extra_model_kwargs['system'] = system + response = client.messages.create( model=model, messages=prompt_message_dicts, - stop_sequences=stop if stop else [], - system=system, stream=stream, **model_parameters, + **extra_model_kwargs ) - if stream is False: - return self._handle_claude3_response(model, credentials, response, prompt_messages) - else: + if stream: return self._handle_claude3_stream_response(model, credentials, response, prompt_messages) + return self._handle_claude3_response(model, credentials, response, prompt_messages) + def _handle_claude3_response(self, model: str, credentials: dict, response: Message, prompt_messages: list[PromptMessage]) -> LLMResult: """ @@ -263,13 +279,22 @@ class BedrockLargeLanguageModel(LargeLanguageModel): """ Convert prompt messages to dict list and system """ - system = "" - prompt_message_dicts = [] + system = "" + first_loop = True for message in prompt_messages: if isinstance(message, SystemPromptMessage): - system += message.content + ("\n" if not system else "") - else: + message.content=message.content.strip() + if first_loop: + system=message.content + first_loop=False + else: + system+="\n" + system+=message.content + + prompt_message_dicts = [] + for message in prompt_messages: + if not isinstance(message, SystemPromptMessage): prompt_message_dicts.append(self._convert_claude3_prompt_message_to_dict(message)) return system, prompt_message_dicts