feat: add claude3 function calling (#5889)

longzhihun 2024-07-03 22:21:02 +08:00 committed by GitHub
parent cb8feb732f
commit aecdfa2d5c
5 changed files with 114 additions and 39 deletions
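
The heart of this change is the new _convert_converse_tool_config helper, which maps Dify PromptMessageTool objects onto the Bedrock Converse API's toolSpec format, plus accumulation of streamed toolUse blocks. A minimal standalone sketch of the request shape this produces, using the plain boto3 client (the get_weather tool, its schema, the region, and the model ID are illustrative assumptions, not part of this commit):

    import boto3

    # Plain bedrock-runtime client; region and model ID are illustrative.
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    # Shape produced by _convert_converse_tool_config: one toolSpec per tool,
    # with the JSON Schema of the arguments under inputSchema.json.
    tool_config = {
        "tools": [{
            "toolSpec": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "inputSchema": {"json": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                }},
            }
        }]
    }

    response = client.converse(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": [{"text": "What is the weather in Paris?"}]}],
        toolConfig=tool_config,
    )
    # When the model decides to call the tool, the reply contains a toolUse
    # content block carrying a toolUseId, the tool name, and the parsed JSON input.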

View File

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

View File

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

View File

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

View File

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

View File

@@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -68,7 +69,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # TODO: consolidate different invocation methods for models based on base model capabilities
         # invoke anthropic models via boto3 client
         if "anthropic" in model:
-            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke Cohere models via boto3 client
         if "cohere.command-r" in model:
             return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@@ -151,7 +152,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
     def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None) -> Union[LLMResult, Generator]:
         """
         Invoke Anthropic large language model
@@ -171,23 +172,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
         inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

+        parameters = {
+            'modelId': model,
+            'messages': prompt_message_dicts,
+            'inferenceConfig': inference_config,
+            'additionalModelRequestFields': additional_model_fields,
+        }
+        if system and len(system) > 0:
+            parameters['system'] = system
+        if tools:
+            parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
         if stream:
-            response = bedrock_client.converse_stream(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse_stream(**parameters)
             return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
         else:
-            response = bedrock_client.converse(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse(**parameters)
             return self._handle_converse_response(model, credentials, response, prompt_messages)

     def _handle_converse_response(self, model: str, credentials: dict, response: dict,
@@ -246,12 +248,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             output_tokens = 0
             finish_reason = None
             index = 0
+            tool_calls: list[AssistantPromptMessage.ToolCall] = []
+            tool_use = {}
             for chunk in response['stream']:
                 if 'messageStart' in chunk:
                     return_model = model
                 elif 'messageStop' in chunk:
                     finish_reason = chunk['messageStop']['stopReason']
+                elif 'contentBlockStart' in chunk:
+                    tool = chunk['contentBlockStart']['start']['toolUse']
+                    tool_use['toolUseId'] = tool['toolUseId']
+                    tool_use['name'] = tool['name']
                 elif 'metadata' in chunk:
                     input_tokens = chunk['metadata']['usage']['inputTokens']
                     output_tokens = chunk['metadata']['usage']['outputTokens']
@@ -260,29 +268,49 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                         model=return_model,
                         prompt_messages=prompt_messages,
                         delta=LLMResultChunkDelta(
-                            index=index + 1,
+                            index=index,
                             message=AssistantPromptMessage(
-                                content=''
+                                content='',
+                                tool_calls=tool_calls
                             ),
                             finish_reason=finish_reason,
                             usage=usage
                         )
                     )
                 elif 'contentBlockDelta' in chunk:
-                    chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
-                    full_assistant_content += chunk_text
-                    assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else '',
-                    )
-                    index = chunk['contentBlockDelta']['contentBlockIndex']
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=index,
-                            message=assistant_prompt_message,
-                        )
-                    )
+                    delta = chunk['contentBlockDelta']['delta']
+                    if 'text' in delta:
+                        chunk_text = delta['text'] if delta['text'] else ''
+                        full_assistant_content += chunk_text
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=chunk_text if chunk_text else '',
+                        )
+                        index = chunk['contentBlockDelta']['contentBlockIndex']
+                        yield LLMResultChunk(
+                            model=model,
+                            prompt_messages=prompt_messages,
+                            delta=LLMResultChunkDelta(
+                                index=index + 1,
+                                message=assistant_prompt_message,
+                            )
+                        )
+                    elif 'toolUse' in delta:
+                        if 'input' not in tool_use:
+                            tool_use['input'] = ''
+                        tool_use['input'] += delta['toolUse']['input']
+                elif 'contentBlockStop' in chunk:
+                    if 'input' in tool_use:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=tool_use['toolUseId'],
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_use['name'],
+                                arguments=tool_use['input']
+                            )
+                        )
+                        tool_calls.append(tool_call)
+                        tool_use = {}

         except Exception as ex:
             raise InvokeError(str(ex))
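
For context on the stream handling above: with the Converse API, tool arguments arrive as string fragments inside contentBlockDelta events and are only complete once contentBlockStop closes the block. A minimal sketch of that event flow against the raw client, reusing the client and tool_config from the earlier sketch (model ID illustrative):

    stream = client.converse_stream(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": [{"text": "Weather in Paris?"}]}],
        toolConfig=tool_config,
    )

    tool_use = {}
    for chunk in stream["stream"]:
        if "contentBlockStart" in chunk:
            # A tool call opens: remember its id and name (text blocks may
            # carry no toolUse here, hence the .get guard).
            start = chunk["contentBlockStart"]["start"].get("toolUse")
            if start:
                tool_use = {"toolUseId": start["toolUseId"], "name": start["name"], "input": ""}
        elif "contentBlockDelta" in chunk:
            delta = chunk["contentBlockDelta"]["delta"]
            if "toolUse" in delta:
                # Arguments stream in as partial JSON text; concatenate them.
                tool_use["input"] += delta["toolUse"]["input"]
        elif "contentBlockStop" in chunk and tool_use:
            print(tool_use)  # complete call: name plus JSON argument string
            tool_use = {}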
@@ -312,16 +340,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
         system = []
-        first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
                 message.content=message.content.strip()
-                if first_loop:
-                    system=message.content
-                    first_loop=False
-                else:
-                    system+="\n"
-                    system+=message.content
+                system.append({"text": message.content})

         prompt_message_dicts = []
         for message in prompt_messages:
@@ -330,6 +352,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         return system, prompt_message_dicts

+    def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] = None) -> dict:
+        tool_config = {}
+        configs = []
+        if tools:
+            for tool in tools:
+                configs.append(
+                    {
+                        "toolSpec": {
+                            "name": tool.name,
+                            "description": tool.description,
+                            "inputSchema": {
+                                "json": tool.parameters
+                            }
+                        }
+                    }
+                )
+            tool_config["tools"] = configs
+        return tool_config
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict
@@ -379,10 +420,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": [{'text': message.content}]}
+            if message.tool_calls:
+                message_dict = {
+                    "role": "assistant", "content": [{
+                        "toolUse": {
+                            "toolUseId": message.tool_calls[0].id,
+                            "name": message.tool_calls[0].function.name,
+                            "input": json.loads(message.tool_calls[0].function.arguments)
+                        }
+                    }]
+                }
+            else:
+                message_dict = {"role": "assistant", "content": [{'text': message.content}]}
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = [{'text': message.content}]
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {
+                "role": "user",
+                "content": [{
+                    "toolResult": {
+                        "toolUseId": message.tool_call_id,
+                        "content": [{"json": {"text": message.content}}]
+                    }
+                }]
+            }
         else:
             raise ValueError(f"Got unknown type {message}")
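
The ToolPromptMessage branch above is what closes the tool-calling loop: the tool's output goes back to the model as a user-role toolResult block referencing the original toolUseId. A sketch of such a follow-up request, again reusing client and tool_config from the first sketch (ids and values illustrative):

    follow_up = [
        {"role": "user", "content": [{"text": "What is the weather in Paris?"}]},
        # The assistant's earlier tool call, replayed with parsed JSON input.
        {"role": "assistant", "content": [{"toolUse": {
            "toolUseId": "tooluse_abc123",
            "name": "get_weather",
            "input": {"city": "Paris"},
        }}]},
        # The tool's result, wrapped exactly as the conversion code does.
        {"role": "user", "content": [{"toolResult": {
            "toolUseId": "tooluse_abc123",
            "content": [{"json": {"text": "18°C, light rain"}}],
        }}]},
    ]
    response = client.converse(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        messages=follow_up,
        toolConfig=tool_config,
    )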
@@ -401,11 +464,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
         prefix = model.split('.')[0]
         model_name = model.split('.')[1]
+
         if isinstance(prompt_messages, str):
             prompt = prompt_messages
         else:
             prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
+
         return self._get_num_tokens_by_gpt2(prompt)

     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -494,6 +559,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 message_text = f"{ai_prompt} {content}"
             elif isinstance(message, SystemPromptMessage):
                 message_text = content
+            elif isinstance(message, ToolPromptMessage):
+                message_text = f"{human_prompt_prefix} {message.content}"
             else:
                 raise ValueError(f"Got unknown type {message}")