feat: claude3 tool call (#3111)

This commit is contained in:
Yeuoly 2024-04-05 15:35:59 +08:00 committed by GitHub
parent 718ac3f83b
commit 25b9ac3df4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 149 additions and 74 deletions

View File

@ -5,6 +5,7 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
model_properties:
mode: chat
context_size: 200000

View File

@ -5,6 +5,7 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
model_properties:
mode: chat
context_size: 200000

View File

@ -5,6 +5,7 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
model_properties:
mode: chat
context_size: 200000

View File

@ -1,4 +1,5 @@
import base64
import json
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
@ -15,6 +16,7 @@ from anthropic.types import (
MessageStreamEvent,
completion_create_params,
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from core.model_runtime.callbacks.base_callback import Callback
@ -27,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@ -70,10 +73,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
@ -109,14 +113,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if system:
extra_model_kwargs['system'] = system
# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
# When tools are supplied, route the request through Anthropic's tools beta
# endpoint so Claude 3 can emit tool_use content blocks; otherwise fall back
# to the plain messages API.
if tools:
# Each tool is converted to Anthropic's {name, description, input_schema} shape.
extra_model_kwargs['tools'] = [
self._transform_tool_prompt(tool) for tool in tools
]
# NOTE(review): client.beta.tools.messages.create is the beta tool-calling
# surface of the anthropic SDK — confirm it exists in the pinned SDK version.
response = client.beta.tools.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
else:
# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
@ -148,6 +164,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
return {
'name': tool.name,
'description': tool.description,
'input_schema': tool.parameters
}
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
@ -193,7 +216,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
client = Anthropic(api_key="")
return client.count_tokens(prompt)
# count_tokens only sees our own prompt text; the tool-use beta additionally
# injects a hidden system prompt on Anthropic's side whose size varies per model.
tokens = client.count_tokens(prompt)
# NOTE(review): these per-model overheads look empirically measured — confirm
# against Anthropic's published tool-use system prompt token counts when the
# SDK or model list is upgraded.
tool_call_inner_prompts_tokens_map = {
'claude-3-opus-20240229': 395,
'claude-3-haiku-20240307': 264,
'claude-3-sonnet-20240229': 159
}
# Only add the overhead when tools are actually being sent with the request.
if model in tool_call_inner_prompts_tokens_map and tools:
tokens += tool_call_inner_prompts_tokens_map[model]
return tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -219,7 +253,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm chat response
@ -232,9 +266,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
# transform assistant message to prompt message
# Claude 3 responses carry a list of typed content blocks: concatenate every
# 'text' block into one string and map each 'tool_use' block onto a ToolCall.
assistant_prompt_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
for content in response.content:
if content.type == 'text':
assistant_prompt_message.content += content.text
elif content.type == 'tool_use':
# Anthropic returns tool input as a dict; serialize it to a JSON string
# to match the OpenAI-style function-call contract used internally.
tool_call = AssistantPromptMessage.ToolCall(
id=content.id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=content.name,
arguments=json.dumps(content.input)
)
)
assistant_prompt_message.tool_calls.append(tool_call)
# calculate num tokens
if response.usage:
# transform usage
@ -356,69 +405,90 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
# User messages: plain strings pass through as-is; multi-modal content is
# expanded into a list of typed sub-messages (text / base64 image sources).
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
prompt_message_dicts.append(message_dict)
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
# data URI: split into mime type header and base64 payload
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
# Anthropic only accepts these four image media types
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
}
sub_messages.append(sub_message_dict)
prompt_message_dicts.append({"role": "user", "content": sub_messages})
# Assistant messages: tool calls become 'tool_use' blocks, text becomes a
# 'text' block; consecutive assistant turns are merged into one entry.
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
content = []
if message.tool_calls:
for tool_call in message.tool_calls:
content.append({
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
# arguments is a JSON string internally; Anthropic wants a dict
"input": json.loads(tool_call.function.arguments)
})
if message.content:
content.append({
"type": "text",
"text": message.content
})
# NOTE(review): assumes at least one earlier non-system message exists —
# an assistant message first in the list would raise IndexError here;
# confirm upstream ordering guarantees.
if prompt_message_dicts[-1]["role"] == "assistant":
prompt_message_dicts[-1]["content"].extend(content)
else:
prompt_message_dicts.append({
"role": "assistant",
"content": content
})
# Tool results are sent back as user-role 'tool_result' blocks, correlated
# to the originating call via tool_use_id.
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": message.content
}]
}
prompt_message_dicts.append(message_dict)
else:
raise ValueError(f"Got unknown type {message}")
return system, prompt_message_dicts
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
    """
    Convert a PromptMessage into the Anthropic messages-API dict format.

    :param message: prompt message to convert
    :return: role/content dict understood by the Anthropic client
    :raises ValueError: on unsupported image mime types, unreachable image
        URLs, or unknown message types
    """
    if isinstance(message, UserPromptMessage):
        message = cast(UserPromptMessage, message)
        if isinstance(message.content, str):
            # plain-text user message passes through unchanged
            return {"role": "user", "content": message.content}
        # multi-modal user message: expand into a list of typed sub-messages
        parts = []
        for item in message.content:
            if item.type == PromptMessageContentType.TEXT:
                item = cast(TextPromptMessageContent, item)
                parts.append({
                    "type": "text",
                    "text": item.data
                })
            elif item.type == PromptMessageContentType.IMAGE:
                item = cast(ImagePromptMessageContent, item)
                if item.data.startswith("data:"):
                    # data URI: split into mime type header and base64 payload
                    segments = item.data.split(";base64,")
                    mime_type = segments[0].replace("data:", "")
                    base64_data = segments[1]
                else:
                    # remote image: download it and base64-encode the bytes
                    try:
                        raw = requests.get(item.data).content
                        mime_type, _ = mimetypes.guess_type(item.data)
                        base64_data = base64.b64encode(raw).decode('utf-8')
                    except Exception as ex:
                        raise ValueError(f"Failed to fetch image data from url {item.data}, {ex}")
                # Anthropic only accepts these four image media types
                if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
                    raise ValueError(f"Unsupported image type {mime_type}, "
                                     f"only support image/jpeg, image/png, image/gif, and image/webp")
                parts.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": mime_type,
                        "data": base64_data
                    }
                })
        return {"role": "user", "content": parts}
    if isinstance(message, AssistantPromptMessage):
        message = cast(AssistantPromptMessage, message)
        return {"role": "assistant", "content": message.content}
    if isinstance(message, SystemPromptMessage):
        message = cast(SystemPromptMessage, message)
        return {"role": "system", "content": message.content}
    raise ValueError(f"Got unknown type {message}")
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
@ -453,6 +523,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
message_text += f"{ai_prompt} [IMAGE]"
# system prompts carry no role prefix in the completion-style text rendering
elif isinstance(message, SystemPromptMessage):
message_text = content
# tool results are rendered as human turns to fit the Human/Assistant
# completion-style prompt format
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt} {message.content}"
else:
raise ValueError(f"Got unknown type {message}")

View File

@ -36,7 +36,7 @@ python-docx~=1.1.0
pypdfium2==4.16.0
resend~=0.7.0
pyjwt~=2.8.0
anthropic~=0.20.0
anthropic~=0.23.1
newspaper3k==0.2.8
google-api-python-client==2.90.0
wikipedia==1.4.0