fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)

This commit is contained in:
-LAN- 2024-06-24 15:35:21 +08:00 committed by GitHub
parent 41ceb6a4eb
commit ba67206bb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 142 deletions

View File

@ -1,14 +1,13 @@
import copy import copy
import logging import logging
from collections.abc import Generator from collections.abc import Generator, Sequence
from typing import Optional, Union, cast from typing import Optional, Union, cast
import tiktoken import tiktoken
from openai import AzureOpenAI, Stream from openai import AzureOpenAI, Stream
from openai.types import Completion from openai.types import Completion
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool, PromptMessageTool,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \ stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator]:
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise ValueError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model # chat model
return self._chat_generate( return self._chat_generate(
model=model, model=model,
@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
user=user user=user
) )
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(
tools: Optional[list[PromptMessageTool]] = None) -> int: self,
model: str,
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get( credentials: dict,
ModelPropertyKey.MODE) prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
) -> int:
base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise ValueError('Base Model Name is required')
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f'Base Model Name {base_model_name} is invalid')
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value: if model_mode == LLMMode.CHAT.value:
# chat model # chat model
return self._num_tokens_from_messages(credentials, prompt_messages, tools) return self._num_tokens_from_messages(credentials, prompt_messages, tools)
else: else:
# text completion model, do not support tool calling # text completion model, do not support tool calling
return self._num_tokens_from_string(credentials, prompt_messages[0].content) content = prompt_messages[0].content
assert isinstance(content, str)
return self._num_tokens_from_string(credentials,content)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
if 'openai_api_base' not in credentials: if 'openai_api_base' not in credentials:
@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if 'base_model_name' not in credentials: if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required') raise CredentialsValidateFailedError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise CredentialsValidateFailedError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity: if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise ValueError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict, def _generate(self, model: str, credentials: dict,
@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Completion, def _handle_generate_response(
prompt_messages: list[PromptMessage]) -> LLMResult: self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]
):
assistant_text = response.choices[0].text assistant_text = response.choices[0].text
# transform assistant message to prompt message # transform assistant message to prompt message
@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
completion_tokens = response.usage.completion_tokens completion_tokens = response.usage.completion_tokens
else: else:
# calculate num tokens # calculate num tokens
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content) content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, assistant_text) completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
# transform usage # transform usage
@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return result return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], def _handle_generate_stream_response(
prompt_messages: list[PromptMessage]) -> Generator: self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]
) -> Generator:
full_text = '' full_text = ''
for chunk in response: for chunk in response:
if len(chunk.choices) == 0: if len(chunk.choices) == 0:
@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
completion_tokens = chunk.usage.completion_tokens completion_tokens = chunk.usage.completion_tokens
else: else:
# calculate num tokens # calculate num tokens
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content) content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, full_text) completion_tokens = self._num_tokens_from_string(credentials, full_text)
# transform usage # transform usage
@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
extra_model_kwargs = {} extra_model_kwargs = {}
if tools: if tools:
# extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
extra_model_kwargs['functions'] = [{ # extra_model_kwargs['functions'] = [{
"name": tool.name, # "name": tool.name,
"description": tool.description, # "description": tool.description,
"parameters": tool.parameters # "parameters": tool.parameters
} for tool in tools] # } for tool in tools]
if stop: if stop:
extra_model_kwargs['stop'] = stop extra_model_kwargs['stop'] = stop
@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
extra_model_kwargs['user'] = user extra_model_kwargs['user'] = user
# chat model # chat model
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
response = client.chat.completions.create( response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], messages=messages,
model=model, model=model,
stream=stream, stream=stream,
**model_parameters, **model_parameters,
@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, def _handle_chat_generate_response(
self, model: str, credentials: dict, response: ChatCompletion,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: tools: Optional[list[PromptMessageTool]] = None
):
assistant_message = response.choices[0].message assistant_message = response.choices[0].message
# assistant_message_tool_calls = assistant_message.tool_calls assistant_message_tool_calls = assistant_message.tool_calls
assistant_message_function_call = assistant_message.function_call
# extract tool calls from response # extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) tool_calls = []
function_call = self._extract_response_function_call(assistant_message_function_call) self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response # transform response
response = LLMResult( result = LLMResult(
model=response.model or model, model=response.model or model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=assistant_prompt_message, message=assistant_prompt_message,
@ -325,59 +354,35 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
system_fingerprint=response.system_fingerprint, system_fingerprint=response.system_fingerprint,
) )
return response return result
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk], response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> Generator: tools: Optional[list[PromptMessageTool]] = None
):
index = 0 index = 0
full_assistant_content = '' full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model real_model = model
system_fingerprint = None system_fingerprint = None
completion = '' completion = ''
tool_calls = []
for chunk in response: for chunk in response:
if len(chunk.choices) == 0: if len(chunk.choices) == 0:
continue continue
delta = chunk.choices[0] delta = chunk.choices[0]
# extract tool calls from response
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
if delta.delta is None or ( if delta.finish_reason is None and not delta.delta.content:
delta.finish_reason is None
and (delta.delta.content is None or delta.delta.content == '')
and delta.delta.function_call is None
):
continue continue
# assistant_message_tool_calls = delta.delta.tool_calls
assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if delta_assistant_message_function_call_storage is not None:
# handle process of stream function call
if assistant_message_function_call:
# message has not ended ever
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
continue
else:
# message has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
continue
# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '', content=delta.delta.content if delta.delta.content else '',
@ -426,12 +431,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
) )
@staticmethod @staticmethod
def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
-> list[AssistantPromptMessage.ToolCall]: if tool_calls_response:
for response_tool_call in tool_calls_response:
tool_calls = [] if isinstance(response_tool_call, ChatCompletionMessageToolCall):
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction( function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments arguments=response_tool_call.function.arguments
@ -443,37 +446,41 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
function=function function=function
) )
tool_calls.append(tool_call) tool_calls.append(tool_call)
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
index = response_tool_call.index
if index < len(tool_calls):
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
if response_tool_call.function:
tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
else:
assert response_tool_call.id is not None
assert response_tool_call.type is not None
assert response_tool_call.function is not None
assert response_tool_call.function.name is not None
assert response_tool_call.function.arguments is not None
return tool_calls
@staticmethod
def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-> AssistantPromptMessage.ToolCall:
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction( function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name, name=response_tool_call.function.name,
arguments=response_function_call.arguments arguments=response_tool_call.function.arguments
) )
tool_call = AssistantPromptMessage.ToolCall( tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.name, id=response_tool_call.id,
type="function", type=response_tool_call.type,
function=function function=function
) )
tool_calls.append(tool_call)
return tool_call
@staticmethod @staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: def _convert_prompt_message_to_dict(message: PromptMessage):
if isinstance(message, UserPromptMessage): if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message) message = cast(UserPromptMessage, message)
if isinstance(message.content, str): if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content} message_dict = {"role": "user", "content": message.content}
else: else:
sub_messages = [] sub_messages = []
assert message.content is not None
for message_content in message.content: for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT: if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content) message_content = cast(TextPromptMessageContent, message_content)
@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
} }
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages} message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message) message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content} message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls: if message.tool_calls:
# message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
# message.tool_calls]
function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage): elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message) message = cast(ToolPromptMessage, message)
# message_dict = {
# "role": "tool",
# "content": message.content,
# "tool_call_id": message.tool_call_id
# }
message_dict = { message_dict = {
"role": "function", "role": "tool",
"name": message.name,
"content": message.content, "content": message.content,
"name": message.tool_call_id "tool_call_id": message.tool_call_id
} }
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens return num_tokens
def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage], def _num_tokens_from_messages(
tools: Optional[list[PromptMessageTool]] = None) -> int: self, credentials: dict, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/ Official documentation: https://github.com/openai/openai-cookbook/blob/
@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if key == "tool_calls": if key == "tool_calls":
for tool_call in value: for tool_call in value:
assert isinstance(tool_call, dict)
for t_key, t_value in tool_call.items(): for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key)) num_tokens += len(encoding.encode(t_key))
if t_key == "function": if t_key == "function":
@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
num_tokens += len(encoding.encode('parameters')) num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters: if 'title' in parameters:
num_tokens += len(encoding.encode('title')) num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title"))) num_tokens += len(encoding.encode(parameters['title']))
num_tokens += len(encoding.encode('type')) num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type"))) num_tokens += len(encoding.encode(parameters['type']))
if 'properties' in parameters: if 'properties' in parameters:
num_tokens += len(encoding.encode('properties')) num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items(): for key, value in parameters['properties'].items():
num_tokens += len(encoding.encode(key)) num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items(): for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(field_key))
@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens return num_tokens
@staticmethod @staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: def _get_ai_model_entity(base_model_name: str, model: str):
for ai_model_entity in LLM_BASE_MODELS: for ai_model_entity in LLM_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name: if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity) ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy return ai_model_entity_copy
return None

View File

@ -73,15 +73,13 @@ class MockChatClass:
return FunctionCall(name=function_name, arguments=dumps(parameters)) return FunctionCall(name=function_name, arguments=dumps(parameters))
@staticmethod @staticmethod
def generate_tool_calls( def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> Optional[list[ChatCompletionMessageToolCall]]:
list_tool_calls = [] list_tool_calls = []
if not tools or len(tools) == 0: if not tools or len(tools) == 0:
return None return None
tool: ChatCompletionToolParam = tools[0] tool = tools[0]
if tools['type'] != 'function': if 'type' in tools and tools['type'] != 'function':
return None return None
function = tool['function'] function = tool['function']