mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:19:12 +08:00
feat: claude3 tool call (#3111)
This commit is contained in:
parent
718ac3f83b
commit
25b9ac3df4
@ -5,6 +5,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
@ -5,6 +5,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
@ -5,6 +5,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
@ -15,6 +16,7 @@ from anthropic.types import (
|
||||
MessageStreamEvent,
|
||||
completion_create_params,
|
||||
)
|
||||
from anthropic.types.beta.tools import ToolsBetaMessage
|
||||
from httpx import Timeout
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
@ -27,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -70,10 +73,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# invoke model
|
||||
return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _chat_generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm chat model
|
||||
@ -109,14 +113,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
if system:
|
||||
extra_model_kwargs['system'] = system
|
||||
|
||||
# chat model
|
||||
response = client.messages.create(
|
||||
model=model,
|
||||
messages=prompt_message_dicts,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
if tools:
|
||||
extra_model_kwargs['tools'] = [
|
||||
self._transform_tool_prompt(tool) for tool in tools
|
||||
]
|
||||
response = client.beta.tools.messages.create(
|
||||
model=model,
|
||||
messages=prompt_message_dicts,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
else:
|
||||
# chat model
|
||||
response = client.messages.create(
|
||||
model=model,
|
||||
messages=prompt_message_dicts,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
@ -148,6 +164,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
|
||||
return {
|
||||
'name': tool.name,
|
||||
'description': tool.description,
|
||||
'input_schema': tool.parameters
|
||||
}
|
||||
|
||||
def _transform_chat_json_prompts(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
||||
@ -193,7 +216,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
|
||||
|
||||
client = Anthropic(api_key="")
|
||||
return client.count_tokens(prompt)
|
||||
tokens = client.count_tokens(prompt)
|
||||
|
||||
tool_call_inner_prompts_tokens_map = {
|
||||
'claude-3-opus-20240229': 395,
|
||||
'claude-3-haiku-20240307': 264,
|
||||
'claude-3-sonnet-20240229': 159
|
||||
}
|
||||
|
||||
if model in tool_call_inner_prompts_tokens_map and tools:
|
||||
tokens += tool_call_inner_prompts_tokens_map[model]
|
||||
|
||||
return tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
@ -219,7 +253,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
@ -232,9 +266,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response.content[0].text
|
||||
content='',
|
||||
tool_calls=[]
|
||||
)
|
||||
|
||||
for content in response.content:
|
||||
if content.type == 'text':
|
||||
assistant_prompt_message.content += content.text
|
||||
elif content.type == 'tool_use':
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=content.id,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=content.name,
|
||||
arguments=json.dumps(content.input)
|
||||
)
|
||||
)
|
||||
assistant_prompt_message.tool_calls.append(tool_call)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
@ -356,69 +405,90 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if not isinstance(message, SystemPromptMessage):
|
||||
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
prompt_message_dicts.append(message_dict)
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
if not message_content.data.startswith("data:"):
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
mime_type, _ = mimetypes.guess_type(message_content.data)
|
||||
base64_data = base64.b64encode(image_content).decode('utf-8')
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
else:
|
||||
data_split = message_content.data.split(";base64,")
|
||||
mime_type = data_split[0].replace("data:", "")
|
||||
base64_data = data_split[1]
|
||||
|
||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
||||
raise ValueError(f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp")
|
||||
|
||||
sub_message_dict = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
prompt_message_dicts.append({"role": "user", "content": sub_messages})
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
content = []
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
content.append({
|
||||
"type": "tool_use",
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"input": json.loads(tool_call.function.arguments)
|
||||
})
|
||||
if message.content:
|
||||
content.append({
|
||||
"type": "text",
|
||||
"text": message.content
|
||||
})
|
||||
|
||||
if prompt_message_dicts[-1]["role"] == "assistant":
|
||||
prompt_message_dicts[-1]["content"].extend(content)
|
||||
else:
|
||||
prompt_message_dicts.append({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
})
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": message.content
|
||||
}]
|
||||
}
|
||||
prompt_message_dicts.append(message_dict)
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return system, prompt_message_dicts
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
if not message_content.data.startswith("data:"):
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
mime_type, _ = mimetypes.guess_type(message_content.data)
|
||||
base64_data = base64.b64encode(image_content).decode('utf-8')
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
else:
|
||||
data_split = message_content.data.split(";base64,")
|
||||
mime_type = data_split[0].replace("data:", "")
|
||||
base64_data = data_split[1]
|
||||
|
||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
||||
raise ValueError(f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp")
|
||||
|
||||
sub_message_dict = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
@ -453,6 +523,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
message_text += f"{ai_prompt} [IMAGE]"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = content
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {message.content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
@ -36,7 +36,7 @@ python-docx~=1.1.0
|
||||
pypdfium2==4.16.0
|
||||
resend~=0.7.0
|
||||
pyjwt~=2.8.0
|
||||
anthropic~=0.20.0
|
||||
anthropic~=0.23.1
|
||||
newspaper3k==0.2.8
|
||||
google-api-python-client==2.90.0
|
||||
wikipedia==1.4.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user