feat: claude3 tool call (#3111)

parent 718ac3f83b
commit 25b9ac3df4
@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000
@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000
@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000
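Note: the three identical hunks above add the `tool-call` entry to the `features` list of what are presumably the three claude-3 model definition YAMLs (opus, sonnet, and haiku, the same models named in the token map later in this diff); this flag is what marks the models as tool-call capable to the rest of Dify.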
@@ -1,4 +1,5 @@
 import base64
+import json
 import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast
@@ -15,6 +16,7 @@ from anthropic.types import (
     MessageStreamEvent,
     completion_create_params,
 )
+from anthropic.types.beta.tools import ToolsBetaMessage
 from httpx import Timeout

 from core.model_runtime.callbacks.base_callback import Callback
@@ -27,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -70,10 +73,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
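Note: `_invoke` already received a `tools` argument but the old call dropped it; this hunk threads it through to `_chat_generate`, which gains a matching optional `tools` keyword ahead of `stop`.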
@@ -109,14 +113,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         if system:
             extra_model_kwargs['system'] = system

-        # chat model
-        response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
+        if tools:
+            extra_model_kwargs['tools'] = [
+                self._transform_tool_prompt(tool) for tool in tools
+            ]
+            response = client.beta.tools.messages.create(
+                model=model,
+                messages=prompt_message_dicts,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs
+            )
+        else:
+            # chat model
+            response = client.messages.create(
+                model=model,
+                messages=prompt_message_dicts,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs
+            )

         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
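With tools present, the request now goes to the SDK's tools beta endpoint rather than the plain messages endpoint. For orientation, a minimal self-contained sketch of such a request against anthropic~=0.23.x; the `get_weather` tool, the placeholder key, and all literal values are illustrative and not part of this commit:

from anthropic import Anthropic

client = Anthropic(api_key="sk-ant-...")  # placeholder credential

response = client.beta.tools.messages.create(
    model="claude-3-opus-20240229",
    max_tokens=1024,
    tools=[{
        "name": "get_weather",                 # illustrative tool
        "description": "Get the current weather for a city",
        "input_schema": {                      # JSON Schema for the arguments
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }],
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
)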
@@ -148,6 +164,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

         return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

+    def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
+        return {
+            'name': tool.name,
+            'description': tool.description,
+            'input_schema': tool.parameters
+        }
+
     def _transform_chat_json_prompts(self, model: str, credentials: dict,
                                      prompt_messages: list[PromptMessage], model_parameters: dict,
                                      tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
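The new `_transform_tool_prompt` helper is a field-for-field rename: Dify's `PromptMessageTool` carries an OpenAI-style `parameters` JSON Schema, and Anthropic expects the same schema under `input_schema`. A hedged sketch of the mapping, with an illustrative weather tool:

from core.model_runtime.entities.message_entities import PromptMessageTool

tool = PromptMessageTool(
    name="get_weather",  # illustrative values throughout
    description="Get the current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

# _transform_tool_prompt(tool) would return:
# {
#     "name": "get_weather",
#     "description": "Get the current weather for a city",
#     "input_schema": {...},  # the same schema, renamed from `parameters`
# }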
@@ -193,7 +216,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)

         client = Anthropic(api_key="")
-        return client.count_tokens(prompt)
+        tokens = client.count_tokens(prompt)
+
+        tool_call_inner_prompts_tokens_map = {
+            'claude-3-opus-20240229': 395,
+            'claude-3-haiku-20240307': 264,
+            'claude-3-sonnet-20240229': 159
+        }
+
+        if model in tool_call_inner_prompts_tokens_map and tools:
+            tokens += tool_call_inner_prompts_tokens_map[model]
+
+        return tokens

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
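Note: `count_tokens` only sees the rendered text prompt, so the new constants add a fixed per-model correction for the hidden tool-use system prompt injected when tools are enabled; the values appear to match the overheads Anthropic documents for the claude-3 family. For example, a request whose prompt counts as 1,000 tokens would be reported as 1,395 when tools are passed to claude-3-opus-20240229.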
@@ -219,7 +253,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
                                        prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm chat response
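Note: the return type widens to `Union[Message, ToolsBetaMessage]` because the tools beta endpoint responds with a `ToolsBetaMessage`, whose `content` list may interleave `tool_use` blocks with plain `text` blocks; the next hunk handles both kinds.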
@@ -232,9 +266,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.content[0].text
+            content='',
+            tool_calls=[]
         )

+        for content in response.content:
+            if content.type == 'text':
+                assistant_prompt_message.content += content.text
+            elif content.type == 'tool_use':
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=content.id,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=content.name,
+                        arguments=json.dumps(content.input)
+                    )
+                )
+                assistant_prompt_message.tool_calls.append(tool_call)
+
         # calculate num tokens
         if response.usage:
             # transform usage
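The loop concatenates every `text` block into the assistant message and converts each `tool_use` block into Dify's OpenAI-style tool call, serializing the structured `input` back into a JSON string. A hedged sketch of the mapping for one illustrative block (the id and values are made up):

import json

from core.model_runtime.entities.message_entities import AssistantPromptMessage

# An illustrative tool_use content block from the beta endpoint would carry:
#   type='tool_use', id='toolu_01A...', name='get_weather', input={'city': 'Paris'}
# and becomes:
tool_call = AssistantPromptMessage.ToolCall(
    id="toolu_01A...",                            # illustrative id
    type="function",                              # normalized to the OpenAI-style kind
    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
        name="get_weather",
        arguments=json.dumps({"city": "Paris"}),  # dict -> JSON string
    ),
)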
@@ -356,69 +405,90 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         prompt_message_dicts = []
         for message in prompt_messages:
             if not isinstance(message, SystemPromptMessage):
-                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
+                if isinstance(message, UserPromptMessage):
+                    message = cast(UserPromptMessage, message)
+                    if isinstance(message.content, str):
+                        message_dict = {"role": "user", "content": message.content}
+                        prompt_message_dicts.append(message_dict)
+                    else:
+                        sub_messages = []
+                        for message_content in message.content:
+                            if message_content.type == PromptMessageContentType.TEXT:
+                                message_content = cast(TextPromptMessageContent, message_content)
+                                sub_message_dict = {
+                                    "type": "text",
+                                    "text": message_content.data
+                                }
+                                sub_messages.append(sub_message_dict)
+                            elif message_content.type == PromptMessageContentType.IMAGE:
+                                message_content = cast(ImagePromptMessageContent, message_content)
+                                if not message_content.data.startswith("data:"):
+                                    # fetch image data from url
+                                    try:
+                                        image_content = requests.get(message_content.data).content
+                                        mime_type, _ = mimetypes.guess_type(message_content.data)
+                                        base64_data = base64.b64encode(image_content).decode('utf-8')
+                                    except Exception as ex:
+                                        raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                                else:
+                                    data_split = message_content.data.split(";base64,")
+                                    mime_type = data_split[0].replace("data:", "")
+                                    base64_data = data_split[1]
+
+                                if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                                    raise ValueError(f"Unsupported image type {mime_type}, "
+                                                     f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                                sub_message_dict = {
+                                    "type": "image",
+                                    "source": {
+                                        "type": "base64",
+                                        "media_type": mime_type,
+                                        "data": base64_data
+                                    }
+                                }
+                                sub_messages.append(sub_message_dict)
+                        prompt_message_dicts.append({"role": "user", "content": sub_messages})
+                elif isinstance(message, AssistantPromptMessage):
+                    message = cast(AssistantPromptMessage, message)
+                    content = []
+                    if message.tool_calls:
+                        for tool_call in message.tool_calls:
+                            content.append({
+                                "type": "tool_use",
+                                "id": tool_call.id,
+                                "name": tool_call.function.name,
+                                "input": json.loads(tool_call.function.arguments)
+                            })
+                    if message.content:
+                        content.append({
+                            "type": "text",
+                            "text": message.content
+                        })
+
+                    if prompt_message_dicts[-1]["role"] == "assistant":
+                        prompt_message_dicts[-1]["content"].extend(content)
+                    else:
+                        prompt_message_dicts.append({
+                            "role": "assistant",
+                            "content": content
+                        })
+                elif isinstance(message, ToolPromptMessage):
+                    message = cast(ToolPromptMessage, message)
+                    message_dict = {
+                        "role": "user",
+                        "content": [{
+                            "type": "tool_result",
+                            "tool_use_id": message.tool_call_id,
+                            "content": message.content
+                        }]
+                    }
+                    prompt_message_dicts.append(message_dict)
+                else:
+                    raise ValueError(f"Got unknown type {message}")

         return system, prompt_message_dicts

-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
-        """
-        Convert PromptMessage to dict
-        """
-        if isinstance(message, UserPromptMessage):
-            message = cast(UserPromptMessage, message)
-            if isinstance(message.content, str):
-                message_dict = {"role": "user", "content": message.content}
-            else:
-                sub_messages = []
-                for message_content in message.content:
-                    if message_content.type == PromptMessageContentType.TEXT:
-                        message_content = cast(TextPromptMessageContent, message_content)
-                        sub_message_dict = {
-                            "type": "text",
-                            "text": message_content.data
-                        }
-                        sub_messages.append(sub_message_dict)
-                    elif message_content.type == PromptMessageContentType.IMAGE:
-                        message_content = cast(ImagePromptMessageContent, message_content)
-                        if not message_content.data.startswith("data:"):
-                            # fetch image data from url
-                            try:
-                                image_content = requests.get(message_content.data).content
-                                mime_type, _ = mimetypes.guess_type(message_content.data)
-                                base64_data = base64.b64encode(image_content).decode('utf-8')
-                            except Exception as ex:
-                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                        else:
-                            data_split = message_content.data.split(";base64,")
-                            mime_type = data_split[0].replace("data:", "")
-                            base64_data = data_split[1]
-
-                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
-                            raise ValueError(f"Unsupported image type {mime_type}, "
-                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
-
-                        sub_message_dict = {
-                            "type": "image",
-                            "source": {
-                                "type": "base64",
-                                "media_type": mime_type,
-                                "data": base64_data
-                            }
-                        }
-                        sub_messages.append(sub_message_dict)
-
-                message_dict = {"role": "user", "content": sub_messages}
-        elif isinstance(message, AssistantPromptMessage):
-            message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": message.content}
-        elif isinstance(message, SystemPromptMessage):
-            message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "system", "content": message.content}
-        else:
-            raise ValueError(f"Got unknown type {message}")
-
-        return message_dict
-
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
         Convert a single message to a string.
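Taken together, this inlined conversion produces the message shapes the Anthropic API expects for a full tool round-trip: assistant turns may carry `tool_use` blocks (merged into the previous assistant entry when two assistant messages are adjacent), and tool results travel back as `tool_result` blocks inside a user turn. A hedged sketch of the resulting list, with illustrative values:

prompt_message_dicts = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "assistant", "content": [
        {"type": "tool_use", "id": "toolu_01A...",   # illustrative id
         "name": "get_weather", "input": {"city": "Paris"}},
    ]},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "toolu_01A...",
         "content": "22C and sunny"},                # illustrative tool output
    ]},
]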
@@ -453,6 +523,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 message_text += f"{ai_prompt} [IMAGE]"
         elif isinstance(message, SystemPromptMessage):
             message_text = content
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {message.content}"
         else:
             raise ValueError(f"Got unknown type {message}")

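Note: this branch covers the legacy text-prompt path used by `_convert_messages_to_prompt_anthropic` (seen above in the token-counting hunk); tool results are rendered there as ordinary human turns.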
@@ -36,7 +36,7 @@ python-docx~=1.1.0
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.20.0
+anthropic~=0.23.1
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
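Note: the dependency bump is what makes the new imports resolvable; the `client.beta.tools` namespace and the `ToolsBetaMessage` type used above are not present in anthropic~=0.20.0, which is presumably why the floor moves to 0.23.1.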