Reviewer summary (text before the first "diff --git" header is not part of any hunk):

  * Declares the `tool-call` feature for claude-3-haiku, claude-3-opus and claude-3-sonnet.
  * llm.py: forwards `tools` to `_chat_generate`, uses `client.beta.tools.messages.create`
    when tools are supplied, maps `tool_use` content blocks to
    `AssistantPromptMessage.ToolCall`, and converts assistant tool calls and
    `ToolPromptMessage` results back into Anthropic message dicts.
  * requirements.txt: bumps anthropic from ~=0.20.0 to ~=0.23.1.

  Review fix: the assistant branch of the message conversion read
  `prompt_message_dicts[-1]` without checking that the list is non-empty, which
  raises IndexError when the first non-system message is an assistant message.
  The lookup is now guarded; the affected hunk headers were adjusted (+2 lines),
  so the `index` blob hash recorded for llm.py no longer describes the result exactly.

diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml
index 073d0c3a7d..0dedb2ef38 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml
@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
index ab3e92a059..60e56452eb 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
index 65cdab9bc6..08c8375d45 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index 724a0401b7..0f87455f4f 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -1,4 +1,5 @@
 import base64
+import json
 import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast
@@ -15,6 +16,7 @@ from anthropic.types import (
     MessageStreamEvent,
     completion_create_params,
 )
+from anthropic.types.beta.tools import ToolsBetaMessage
 from httpx import Timeout
 
 from core.model_runtime.callbacks.base_callback import Callback
@@ -27,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -70,10 +73,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -109,14 +113,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         if system:
             extra_model_kwargs['system'] = system
 
-        # chat model
-        response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
+        if tools:
+            extra_model_kwargs['tools'] = [
+                self._transform_tool_prompt(tool) for tool in tools
+            ]
+            response = client.beta.tools.messages.create(
+                model=model,
+                messages=prompt_message_dicts,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs
+            )
+        else:
+            # chat model
+            response = client.messages.create(
+                model=model,
+                messages=prompt_message_dicts,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs
+            )
 
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
@@ -148,6 +164,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
 
         return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
+    def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
+        return {
+            'name': tool.name,
+            'description': tool.description,
+            'input_schema': tool.parameters
+        }
+
     def _transform_chat_json_prompts(self, model: str, credentials: dict,
                                      prompt_messages: list[PromptMessage], model_parameters: dict,
                                      tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
@@ -193,7 +216,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
 
         client = Anthropic(api_key="")
-        return client.count_tokens(prompt)
+        tokens = client.count_tokens(prompt)
+
+        tool_call_inner_prompts_tokens_map = {
+            'claude-3-opus-20240229': 395,
+            'claude-3-haiku-20240307': 264,
+            'claude-3-sonnet-20240229': 159
+        }
+
+        if model in tool_call_inner_prompts_tokens_map and tools:
+            tokens += tool_call_inner_prompts_tokens_map[model]
+
+        return tokens
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -219,7 +253,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
                                        prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm chat response
@@ -232,9 +266,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.content[0].text
+            content='',
+            tool_calls=[]
         )
 
+        for content in response.content:
+            if content.type == 'text':
+                assistant_prompt_message.content += content.text
+            elif content.type == 'tool_use':
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=content.id,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=content.name,
+                        arguments=json.dumps(content.input)
+                    )
+                )
+                assistant_prompt_message.tool_calls.append(tool_call)
+
         # calculate num tokens
         if response.usage:
             # transform usage
@@ -356,69 +405,92 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         prompt_message_dicts = []
         for message in prompt_messages:
             if not isinstance(message, SystemPromptMessage):
-                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
+                if isinstance(message, UserPromptMessage):
+                    message = cast(UserPromptMessage, message)
+                    if isinstance(message.content, str):
+                        message_dict = {"role": "user", "content": message.content}
+                        prompt_message_dicts.append(message_dict)
+                    else:
+                        sub_messages = []
+                        for message_content in message.content:
+                            if message_content.type == PromptMessageContentType.TEXT:
+                                message_content = cast(TextPromptMessageContent, message_content)
+                                sub_message_dict = {
+                                    "type": "text",
+                                    "text": message_content.data
+                                }
+                                sub_messages.append(sub_message_dict)
+                            elif message_content.type == PromptMessageContentType.IMAGE:
+                                message_content = cast(ImagePromptMessageContent, message_content)
+                                if not message_content.data.startswith("data:"):
+                                    # fetch image data from url
+                                    try:
+                                        image_content = requests.get(message_content.data).content
+                                        mime_type, _ = mimetypes.guess_type(message_content.data)
+                                        base64_data = base64.b64encode(image_content).decode('utf-8')
+                                    except Exception as ex:
+                                        raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                                else:
+                                    data_split = message_content.data.split(";base64,")
+                                    mime_type = data_split[0].replace("data:", "")
+                                    base64_data = data_split[1]
+
+                                if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                                    raise ValueError(f"Unsupported image type {mime_type}, "
+                                                     f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                                sub_message_dict = {
+                                    "type": "image",
+                                    "source": {
+                                        "type": "base64",
+                                        "media_type": mime_type,
+                                        "data": base64_data
+                                    }
+                                }
+                                sub_messages.append(sub_message_dict)
+                        prompt_message_dicts.append({"role": "user", "content": sub_messages})
+                elif isinstance(message, AssistantPromptMessage):
+                    message = cast(AssistantPromptMessage, message)
+                    content = []
+                    if message.tool_calls:
+                        for tool_call in message.tool_calls:
+                            content.append({
+                                "type": "tool_use",
+                                "id": tool_call.id,
+                                "name": tool_call.function.name,
+                                "input": json.loads(tool_call.function.arguments)
+                            })
+                    if message.content:
+                        content.append({
+                            "type": "text",
+                            "text": message.content
+                        })
+
+                    # The list is still empty when the first non-system message comes from the
+                    # assistant, so check it before reading [-1] to merge consecutive assistant turns.
+                    if prompt_message_dicts and prompt_message_dicts[-1]["role"] == "assistant":
+                        prompt_message_dicts[-1]["content"].extend(content)
+                    else:
+                        prompt_message_dicts.append({
+                            "role": "assistant",
+                            "content": content
+                        })
+                elif isinstance(message, ToolPromptMessage):
+                    message = cast(ToolPromptMessage, message)
+                    message_dict = {
+                        "role": "user",
+                        "content": [{
+                            "type": "tool_result",
+                            "tool_use_id": message.tool_call_id,
+                            "content": message.content
+                        }]
+                    }
+                    prompt_message_dicts.append(message_dict)
+                else:
+                    raise ValueError(f"Got unknown type {message}")
 
         return system, prompt_message_dicts
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
-        """
-        Convert PromptMessage to dict
-        """
-        if isinstance(message, UserPromptMessage):
-            message = cast(UserPromptMessage, message)
-            if isinstance(message.content, str):
-                message_dict = {"role": "user", "content": message.content}
-            else:
-                sub_messages = []
-                for message_content in message.content:
-                    if message_content.type == PromptMessageContentType.TEXT:
-                        message_content = cast(TextPromptMessageContent, message_content)
-                        sub_message_dict = {
-                            "type": "text",
-                            "text": message_content.data
-                        }
-                        sub_messages.append(sub_message_dict)
-                    elif message_content.type == PromptMessageContentType.IMAGE:
-                        message_content = cast(ImagePromptMessageContent, message_content)
-                        if not message_content.data.startswith("data:"):
-                            # fetch image data from url
-                            try:
-                                image_content = requests.get(message_content.data).content
-                                mime_type, _ = mimetypes.guess_type(message_content.data)
-                                base64_data = base64.b64encode(image_content).decode('utf-8')
-                            except Exception as ex:
-                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                        else:
-                            data_split = message_content.data.split(";base64,")
-                            mime_type = data_split[0].replace("data:", "")
-                            base64_data = data_split[1]
-
-                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
-                            raise ValueError(f"Unsupported image type {mime_type}, "
-                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
-
-                        sub_message_dict = {
-                            "type": "image",
-                            "source": {
-                                "type": "base64",
-                                "media_type": mime_type,
-                                "data": base64_data
-                            }
-                        }
-                        sub_messages.append(sub_message_dict)
-
-                message_dict = {"role": "user", "content": sub_messages}
-        elif isinstance(message, AssistantPromptMessage):
-            message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": message.content}
-        elif isinstance(message, SystemPromptMessage):
-            message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "system", "content": message.content}
-        else:
-            raise ValueError(f"Got unknown type {message}")
-
-        return message_dict
-
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
         Convert a single message to a string.
@@ -453,6 +525,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                         message_text += f"{ai_prompt} [IMAGE]"
         elif isinstance(message, SystemPromptMessage):
             message_text = content
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {message.content}"
         else:
             raise ValueError(f"Got unknown type {message}")
 
diff --git a/api/requirements.txt b/api/requirements.txt
index cbfde69125..dcf679adcf 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -36,7 +36,7 @@ python-docx~=1.1.0
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.20.0
+anthropic~=0.23.1
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0