Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git, synced 2025-08-13 03:09:01 +08:00
fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)
This commit is contained in:
parent 41ceb6a4eb
commit ba67206bb9
@@ -1,14 +1,13 @@
 import copy
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast

 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall

 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
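
Note: the import changes preview the whole commit — `Sequence` is needed by the new `_update_tool_calls` signature below, while `ChoiceDeltaFunctionCall` and `FunctionCall` belong to the deprecated function-call API and drop out. For orientation, a minimal sketch of the two assistant-message shapes in the OpenAI chat API (illustrative values, not part of the diff):

    # Sketch: legacy function_call vs. current tool_calls, as raw JSON bodies.
    legacy_message = {
        "role": "assistant",
        "content": None,
        "function_call": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    }

    current_message = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_abc123",  # the id later links the tool result to this call
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }],
    }

    # tool_calls is a list, so one assistant turn can request several tools.
    print(len(current_message["tool_calls"]))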
@@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageFunction,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
+from core.model_runtime.utils import helper

 logger = logging.getLogger(__name__)

@@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                  stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+        if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
             return self._chat_generate(
                 model=model,
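
Note: the guard exists because `credentials.get('base_model_name')` is `Optional[str]` while `_get_ai_model_entity` expects a `str`; raising early both fails fast on bad credentials and narrows the type for checkers. A minimal sketch of the narrowing (illustrative):

    from typing import Optional

    # Sketch: dict.get returns Optional[str]; an explicit guard narrows it.
    def lookup_base_model(credentials: dict) -> str:
        base_model_name: Optional[str] = credentials.get('base_model_name')
        if not base_model_name:
            raise ValueError('Base Model Name is required')
        return base_model_name  # narrowed to str from here on

    print(lookup_base_model({'base_model_name': 'gpt-35-turbo'}))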
@@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             user=user
         )

-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
+        if not model_entity:
+            raise ValueError(f'Base Model Name {base_model_name} is invalid')
+        model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)

         if model_mode == LLMMode.CHAT.value:
             # chat model
             return self._num_tokens_from_messages(credentials, prompt_messages, tools)
         else:
             # text completion model, do not support tool calling
-            return self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            return self._num_tokens_from_string(credentials, content)

     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'openai_api_base' not in credentials:
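
Note: `PromptMessage.content` is a union — a plain string or a list of content parts (text, images) — while `_num_tokens_from_string` needs plain text, so the assert narrows the union on the completion-model path. The same pattern recurs in the response handlers below. A sketch of the idea (illustrative types):

    from typing import Union

    # Sketch: content is either a string or a list of content parts;
    # completion models only take plain text, so narrow with an assert.
    Content = Union[str, list[dict]]

    def num_tokens_from_string(text: str) -> int:
        return len(text.split())  # stand-in for the tiktoken-based counter

    def count(content: Content) -> int:
        assert isinstance(content, str)
        return num_tokens_from_string(content)

    print(count("hello world"))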
@@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if 'base_model_name' not in credentials:
             raise CredentialsValidateFailedError('Base Model Name is required')

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise CredentialsValidateFailedError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))

     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None

     def _generate(self, model: str, credentials: dict,
@@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return self._handle_generate_response(model, credentials, response, prompt_messages)

-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+            self, model: str, credentials: dict, response: Completion,
+            prompt_messages: list[PromptMessage]
+    ):
         assistant_text = response.choices[0].text

         # transform assistant message to prompt message
@@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = response.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, assistant_text)

         # transform usage
@@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return result

-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+            self, model: str, credentials: dict, response: Stream[Completion],
+            prompt_messages: list[PromptMessage]
+    ) -> Generator:
         full_text = ''
         for chunk in response:
             if len(chunk.choices) == 0:
@@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = chunk.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, full_text)

         # transform usage
@@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         extra_model_kwargs = {}

         if tools:
-            # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            extra_model_kwargs['functions'] = [{
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters
-            } for tool in tools]
+            extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+            # extra_model_kwargs['functions'] = [{
+            #     "name": tool.name,
+            #     "description": tool.description,
+            #     "parameters": tool.parameters
+            # } for tool in tools]

         if stop:
             extra_model_kwargs['stop'] = stop
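
Note: this hunk is the heart of the commit — the request now sends the `tools` parameter instead of the deprecated top-level `functions` list. `helper.dump_model(PromptMessageFunction(function=tool))` is expected to wrap each tool in the `{"type": "function", "function": {...}}` envelope the tools API takes. A sketch of the payload difference (illustrative function definition):

    # Sketch: the request payload before and after the switch.
    # Deprecated top-level `functions` parameter:
    functions = [{
        "name": "get_weather",
        "description": "Look up current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }]

    # Current `tools` parameter: the same function wrapped with a type tag.
    tools = [{"type": "function", "function": functions[0]}]
    print(tools[0]["type"])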
@@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             extra_model_kwargs['user'] = user

         # chat model
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
-                                       prompt_messages: list[PromptMessage],
-                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
+    def _handle_chat_generate_response(
+            self, model: str, credentials: dict, response: ChatCompletion,
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ):
         assistant_message = response.choices[0].message
-        # assistant_message_tool_calls = assistant_message.tool_calls
-        assistant_message_function_call = assistant_message.function_call
+        assistant_message_tool_calls = assistant_message.tool_calls

         # extract tool calls from response
-        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-        function_call = self._extract_response_function_call(assistant_message_function_call)
-        tool_calls = [function_call] if function_call else []
+        tool_calls = []
+        self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)

         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

         # transform response
-        response = LLMResult(
+        result = LLMResult(
             model=response.model or model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
@@ -325,59 +354,35 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             system_fingerprint=response.system_fingerprint,
         )

-        return response
+        return result

-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
-                                              response: Stream[ChatCompletionChunk],
-                                              prompt_messages: list[PromptMessage],
-                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(
+            self,
+            model: str,
+            credentials: dict,
+            response: Stream[ChatCompletionChunk],
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ):
         index = 0
         full_assistant_content = ''
-        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
+        tool_calls = []
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue

             delta = chunk.choices[0]

+            # extract tool calls from response
+            self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
+
             # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
-            if delta.delta is None or (
-                delta.finish_reason is None
-                and (delta.delta.content is None or delta.delta.content == '')
-                and delta.delta.function_call is None
-            ):
+            if delta.finish_reason is None and not delta.delta.content:
                 continue

-            # assistant_message_tool_calls = delta.delta.tool_calls
-            assistant_message_function_call = delta.delta.function_call
-
-            # extract tool calls from response
-            if delta_assistant_message_function_call_storage is not None:
-                # handle process of stream function call
-                if assistant_message_function_call:
-                    # message has not ended ever
-                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
-                    continue
-                else:
-                    # message has ended
-                    assistant_message_function_call = delta_assistant_message_function_call_storage
-                    delta_assistant_message_function_call_storage = None
-            else:
-                if assistant_message_function_call:
-                    # start of stream function call
-                    delta_assistant_message_function_call_storage = assistant_message_function_call
-                    if delta_assistant_message_function_call_storage.arguments is None:
-                        delta_assistant_message_function_call_storage.arguments = ''
-                    continue
-
-            # extract tool calls from response
-            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-            function_call = self._extract_response_function_call(assistant_message_function_call)
-            tool_calls = [function_call] if function_call else []
-
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
                 content=delta.delta.content if delta.delta.content else '',
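
Note: two things happen here. First, `_update_tool_calls` runs on every chunk before the empty-content guard; that ordering matters because tool-call chunks usually carry no `content` and no `finish_reason` and would otherwise be skipped. Second, the hand-rolled `delta_assistant_message_function_call_storage` buffering disappears: the tools API tags every streamed fragment with an `index`, so accumulation becomes generic. What a streamed tool call typically looks like across chunks (illustrative fragments, not from the diff):

    # Sketch: the opening fragment for an index carries id/type/name;
    # follow-ups carry only pieces of the arguments string, in order.
    fragments = [
        {"index": 0, "id": "call_abc123", "type": "function",
         "function": {"name": "get_weather", "arguments": ""}},
        {"index": 0, "id": None, "type": None,
         "function": {"name": None, "arguments": '{"city"'}},
        {"index": 0, "id": None, "type": None,
         "function": {"name": None, "arguments": ': "Paris"}'}},
    ]

    arguments = "".join(f["function"]["arguments"] or "" for f in fragments)
    print(arguments)  # {"city": "Paris"}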
@@ -426,12 +431,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         )

     @staticmethod
-    def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-            -> list[AssistantPromptMessage.ToolCall]:
-        tool_calls = []
-        if response_tool_calls:
-            for response_tool_call in response_tool_calls:
+    def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
+        if tool_calls_response:
+            for response_tool_call in tool_calls_response:
+                if isinstance(response_tool_call, ChatCompletionMessageToolCall):
                     function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                         name=response_tool_call.function.name,
                         arguments=response_tool_call.function.arguments
@@ -443,37 +446,41 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                         function=function
                     )
                     tool_calls.append(tool_call)
-
-        return tool_calls
-
-    @staticmethod
-    def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-            -> AssistantPromptMessage.ToolCall:
-        tool_call = None
-        if response_function_call:
-            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call.name,
-                arguments=response_function_call.arguments
-            )
-
-            tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call.name,
-                type="function",
-                function=function
-            )
-
-        return tool_call
+                elif isinstance(response_tool_call, ChoiceDeltaToolCall):
+                    index = response_tool_call.index
+                    if index < len(tool_calls):
+                        tool_calls[index].id = response_tool_call.id or tool_calls[index].id
+                        tool_calls[index].type = response_tool_call.type or tool_calls[index].type
+                        if response_tool_call.function:
+                            tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
+                            tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
+                    else:
+                        assert response_tool_call.id is not None
+                        assert response_tool_call.type is not None
+                        assert response_tool_call.function is not None
+                        assert response_tool_call.function.name is not None
+                        assert response_tool_call.function.arguments is not None
+                        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=response_tool_call.function.name,
+                            arguments=response_tool_call.function.arguments
+                        )
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=response_tool_call.id,
+                            type=response_tool_call.type,
+                            function=function
+                        )
+                        tool_calls.append(tool_call)

     @staticmethod
-    def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(message: PromptMessage):
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
                 sub_messages = []
+                assert message.content is not None
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
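
Note: the rewritten helper folds both transports into one merge. Complete `ChatCompletionMessageToolCall` objects append directly; `ChoiceDeltaToolCall` fragments merge by `index` — a known index extends the stored call (note the `function.arguments +=`), and an unseen index must be an opening fragment, which the asserts encode. A self-contained analogue with plain dicts standing in for `AssistantPromptMessage.ToolCall`:

    # Sketch: merge-by-index accumulation (dicts, not the dify entity).
    def update_tool_calls(tool_calls: list, fragments: list) -> None:
        for frag in fragments:
            i = frag["index"]
            if i < len(tool_calls):  # continuation: extend the stored call
                tool_calls[i]["id"] = frag["id"] or tool_calls[i]["id"]
                fn = frag["function"]
                tool_calls[i]["function"]["name"] = fn["name"] or tool_calls[i]["function"]["name"]
                tool_calls[i]["function"]["arguments"] += fn["arguments"] or ""
            else:  # opening fragment: must carry id and name
                assert frag["id"] is not None and frag["function"]["name"] is not None
                tool_calls.append({"id": frag["id"], "function": dict(frag["function"])})

    calls = []
    update_tool_calls(calls, [
        {"index": 0, "id": "call_1", "function": {"name": "f", "arguments": '{"x":'}},
        {"index": 0, "id": None, "function": {"name": None, "arguments": ' 1}'}},
    ])
    print(calls[0]["function"]["arguments"])  # {"x": 1}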
@@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                         }
                     }
                     sub_messages.append(sub_message_dict)

                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
-                #                               message.tool_calls]
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
+                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            # message_dict = {
-            #     "role": "tool",
-            #     "content": message.content,
-            #     "tool_call_id": message.tool_call_id
-            # }
             message_dict = {
-                "role": "function",
+                "role": "tool",
+                "name": message.name,
                 "content": message.content,
-                "name": message.tool_call_id
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Got unknown type {message}")
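
Note: outgoing serialization mirrors the response side. An assistant turn now carries the full `tool_calls` list (the old code could forward only the first call through `function_call`), and a tool result goes back as `role: "tool"` with the `tool_call_id` that ties it to the request, replacing the deprecated `role: "function"` pairing. The wire shape of one round trip (illustrative values):

    # Sketch: the conversation shape the new serialization produces.
    messages = [
        {"role": "user", "content": "What's the weather in Paris?"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }],
        },
        # The result is linked back by tool_call_id, not by name alone.
        {"role": "tool", "name": "get_weather", "content": "18°C, cloudy",
         "tool_call_id": "call_abc123"},
    ]
    print(messages[-1]["tool_call_id"])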
@@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return num_tokens

-    def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def _num_tokens_from_messages(
+            self, credentials: dict, messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
@@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

                 if key == "tool_calls":
                     for tool_call in value:
+                        assert isinstance(tool_call, dict)
                         for t_key, t_value in tool_call.items():
                             num_tokens += len(encoding.encode(t_key))
                             if t_key == "function":
@@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             num_tokens += len(encoding.encode('parameters'))
             if 'title' in parameters:
                 num_tokens += len(encoding.encode('title'))
-                num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode(parameters['title']))
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(parameters.get("type")))
+            num_tokens += len(encoding.encode(parameters['type']))
             if 'properties' in parameters:
                 num_tokens += len(encoding.encode('properties'))
-                for key, value in parameters.get('properties').items():
+                for key, value in parameters['properties'].items():
                     num_tokens += len(encoding.encode(key))
                     for field_key, field_value in value.items():
                         num_tokens += len(encoding.encode(field_key))
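
Note: replacing `parameters.get("title")` with `parameters['title']` is a typing fix more than a behavioral one — inside the `if 'title' in parameters` guard the key is present, and indexing yields a definite value, whereas `.get` returns an Optional that `encoding.encode` cannot be shown to accept. A sketch:

    import tiktoken

    # Sketch: after an `in` check, indexing returns a non-optional value,
    # so encode() type-checks; .get() would hand it an Optional.
    encoding = tiktoken.get_encoding("cl100k_base")
    parameters = {"title": "Weather", "type": "object"}

    num_tokens = 0
    if "title" in parameters:
        num_tokens += len(encoding.encode("title"))
        num_tokens += len(encoding.encode(parameters["title"]))
    print(num_tokens)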
@@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         return num_tokens

     @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str):
         for ai_model_entity in LLM_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
                 ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 ai_model_entity_copy.entity.label.en_US = model
                 ai_model_entity_copy.entity.label.zh_Hans = model
                 return ai_model_entity_copy
-
-        return None
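
Note: the dropped trailing `return None` leans on Python's implicit return — falling off the end of `_get_ai_model_entity` already yields `None`. It pairs with the removed `-> AzureBaseModel` annotation, which was inaccurate for the miss case; callers guard explicitly instead. A sketch:

    # Sketch: a lookup that falls through to an implicit None on a miss.
    BASE_MODELS = {"gpt-35-turbo", "gpt-4"}

    def get_entity(base_model_name: str):
        for name in BASE_MODELS:
            if name == base_model_name:
                return {"base_model_name": name}
        # no explicit return: a miss yields None

    if not get_entity("unknown-model"):
        print("Base Model Name unknown-model is invalid")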
@@ -73,15 +73,13 @@ class MockChatClass:
         return FunctionCall(name=function_name, arguments=dumps(parameters))

     @staticmethod
-    def generate_tool_calls(
-        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-    ) -> Optional[list[ChatCompletionMessageToolCall]]:
+    def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
         list_tool_calls = []
         if not tools or len(tools) == 0:
             return None
-        tool: ChatCompletionToolParam = tools[0]
+        tool = tools[0]

-        if tools['type'] != 'function':
+        if 'type' in tools and tools['type'] != 'function':
             return None

         function = tool['function']
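
Note: the last hunk adjusts the test double (`MockChatClass`). The loosened `tools = NOT_GIVEN` signature accepts whatever the runtime passes without annotating against `ChatCompletionToolParam`/`NotGiven`, and the `'type' in tools` membership check avoids the old `tools['type']` lookup, which raises when `tools` is a plain list. For reference, a sketch of building the object such a mock returns, using the openai v1 types (assumed import path for `Function`):

    from json import dumps

    from openai.types.chat import ChatCompletionMessageToolCall
    from openai.types.chat.chat_completion_message_tool_call import Function

    # Sketch: the tool-call object a mocked chat completion can return.
    tool_call = ChatCompletionMessageToolCall(
        id="mock-call-1",
        type="function",
        function=Function(name="get_weather", arguments=dumps({"city": "Paris"})),
    )
    print(tool_call.function.name)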