From a2c068d94925d040d842e5a72cb5df8ab334f0aa Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Tue, 9 Apr 2024 15:30:09 +0800
Subject: [PATCH] feat: moonshot function call (#3227)

---
 api/core/agent/cot_agent_runner.py             |   2 +-
 api/core/agent/fc_agent_runner.py              |  24 +-
 .../model_providers/moonshot/llm/llm.py        | 320 +++++++++++++++++-
 .../model_providers/moonshot/moonshot.yaml     |  49 +++
 .../openai_api_compatible/llm/llm.py           |  49 ++-
 5 files changed, 423 insertions(+), 21 deletions(-)

diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index d57f15638c..4ad5df5cfc 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -687,4 +687,4 @@ class CotAgentRunner(BaseAgentRunner):
         try:
             return json.dumps(tools, ensure_ascii=False)
         except json.JSONDecodeError:
-            return json.dumps(tools)
+            return json.dumps(tools)
\ No newline at end of file
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index e66500d327..732a6ee750 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -207,19 +207,25 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 )
             )

+        assistant_message = AssistantPromptMessage(
+            content='',
+            tool_calls=[]
+        )
         if tool_calls:
-            prompt_messages.append(AssistantPromptMessage(
-                content='',
-                name='',
-                tool_calls=[AssistantPromptMessage.ToolCall(
+            assistant_message.tool_calls=[
+                AssistantPromptMessage.ToolCall(
                     id=tool_call[0],
                     type='function',
                     function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                         name=tool_call[1],
                         arguments=json.dumps(tool_call[2], ensure_ascii=False)
                     )
-                ) for tool_call in tool_calls]
-            ))
+                ) for tool_call in tool_calls
+            ]
+        else:
+            assistant_message.content = response
+
+        prompt_messages.append(assistant_message)

         # save thought
         self.save_agent_thought(
@@ -239,12 +245,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             final_answer += response + '\n'

-            # update prompt messages
-            if response.strip():
-                prompt_messages.append(AssistantPromptMessage(
-                    content=response,
-                ))
-
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py
index 05feee877e..90045b210e 100644
--- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py
+++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py
@@ -1,8 +1,31 @@
+import json
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

-from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+import requests
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
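
The fc_agent_runner.py hunks above make the runner build a single assistant
message per round, filling in either tool_calls or content, instead of
appending two separate messages as before. A minimal standalone sketch of that
pattern, with simplified stand-in types rather than Dify's actual entities:

    # One assistant message per round: tool calls OR text, never both appended
    # separately. Simplified stand-in types, not Dify's actual entities.
    import json
    from dataclasses import dataclass, field

    @dataclass
    class ToolCall:
        id: str
        name: str
        arguments: str  # JSON-encoded arguments

    @dataclass
    class AssistantMessage:
        content: str = ''
        tool_calls: list[ToolCall] = field(default_factory=list)

    def build_assistant_message(response: str,
                                tool_calls: list[tuple[str, str, dict]]) -> AssistantMessage:
        message = AssistantMessage(content='', tool_calls=[])
        if tool_calls:
            # tool-call round: content stays empty, the calls carry the payload
            message.tool_calls = [
                ToolCall(id=call_id, name=name,
                         arguments=json.dumps(args, ensure_ascii=False))
                for call_id, name, args in tool_calls
            ]
        else:
            # plain-text round: the model's answer goes into content
            message.content = response
        return message

    print(build_assistant_message('hi', []))
    print(build_assistant_message('', [('call_1', 'get_weather', {'city': 'Beijing'})]))
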
@@ -13,6 +36,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
+        self._add_function_call(model, credentials)
         user = user[:32] if user else None
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

@@ -20,7 +44,293 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         self._add_custom_parameters(credentials)
         super().validate_credentials(model, credentials)

-    @staticmethod
-    def _add_custom_parameters(credentials: dict) -> None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model, zh_Hans=model),
+            model_type=ModelType.LLM,
+            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
+                if credentials.get('function_calling_type') == 'tool_call'
+                else [],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name='temperature',
+                    use_template='temperature',
+                    label=I18nObject(en_US='Temperature', zh_Hans='温度'),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name='max_tokens',
+                    use_template='max_tokens',
+                    default=512,
+                    min=1,
+                    max=int(credentials.get('max_tokens', 4096)),
+                    label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
+                    type=ParameterType.INT,
+                ),
+                ParameterRule(
+                    name='top_p',
+                    use_template='top_p',
+                    label=I18nObject(en_US='Top P', zh_Hans='Top P'),
+                    type=ParameterType.FLOAT,
+                ),
+            ]
+        )
+
+    def _add_custom_parameters(self, credentials: dict) -> None:
         credentials['mode'] = 'chat'
         credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
+
+    def _add_function_call(self, model: str, credentials: dict) -> None:
+        model_schema = self.get_model_schema(model, credentials)
+        if model_schema and set([
+            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
+        ]).intersection(model_schema.features or []):
+            credentials['function_calling_type'] = 'tool_call'
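
_add_function_call above is the switch for the new feature: when the model's
schema advertises tool calling, it sets function_calling_type in the
credentials, which the OpenAI-compatible base class consults when deciding
whether to send tools in the request. A minimal sketch of that gate, with an
illustrative enum rather than Dify's (values are stand-ins):

    # Feature gate: enable OpenAI-style tool calling only when the model
    # schema declares support. Enum values here are illustrative stand-ins.
    from enum import Enum

    class ModelFeature(Enum):
        TOOL_CALL = 'tool-call'
        MULTI_TOOL_CALL = 'multi-tool-call'

    def add_function_call(features: list[ModelFeature], credentials: dict) -> None:
        # mark the credentials so downstream request building includes `tools`
        if {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(features):
            credentials['function_calling_type'] = 'tool_call'

    creds: dict = {}
    add_function_call([ModelFeature.TOOL_CALL], creds)
    assert creds == {'function_calling_type': 'tool_call'}
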
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for OpenAI API format
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = []
+                for function_call in message.tool_calls:
+                    message_dict["tool_calls"].append({
+                        "id": function_call.id,
+                        "type": function_call.type,
+                        "function": {
+                            "name": f"functions.{function_call.function.name}",
+                            "arguments": function_call.function.arguments
+                        }
+                    })
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
+            if not message.name.startswith("functions."):
+                message.name = f"functions.{message.name}"
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name:
+            message_dict["name"] = message.name
+
+        return message_dict
+
+    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
+        """
+        Extract tool calls from response
+
+        :param response_tool_calls: response tool calls
+        :return: list of tool calls
+        """
+        tool_calls = []
+        if response_tool_calls:
+            for response_tool_call in response_tool_calls:
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
+                    arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
+                )
+
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=response_tool_call["id"] if response_tool_call.get("id") else "",
+                    type=response_tool_call["type"] if response_tool_call.get("type") else "",
+                    function=function
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: streamed response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ''
+        chunk_index = 0
+
+        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
+                -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=message,
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = "Unknown"
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_name: str):
+                if not tool_name:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id='',
+                        type='',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
+                    )
+                    tools_calls.append(tool_call)
+
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.function.name)
+                # update tool call
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    # remove the "functions." prefix
+                    if new_tool_call.function.name.startswith('functions.'):
+                        parts = new_tool_call.function.name.split('functions.')
+                        if len(parts) > 1:
+                            new_tool_call.function.name = parts[1]
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
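
increase_tool_call above merges streamed tool-call fragments keyed by function
name (Moonshot deltas may arrive without an id), concatenating argument
fragments and stripping the "functions." prefix the API adds. A standalone
sketch of that merging strategy, with simplified stand-in types (the hunk
itself continues below):

    # Merge streamed tool-call deltas: named fragments open or select a call,
    # unnamed fragments extend the most recent one. Stand-in types, not Dify's.
    from dataclasses import dataclass

    @dataclass
    class ToolCallDelta:
        id: str = ''
        name: str = ''
        arguments: str = ''

    def merge_deltas(deltas: list[ToolCallDelta]) -> list[ToolCallDelta]:
        calls: list[ToolCallDelta] = []
        for delta in deltas:
            if delta.name:
                # named fragment: strip a "functions." prefix if present
                name = delta.name.removeprefix('functions.')
                target = next((c for c in calls if c.name == name), None)
                if target is None:
                    target = ToolCallDelta(name=name)
                    calls.append(target)
            else:
                # unnamed fragment continues the most recent call
                target = calls[-1]
            if delta.id:
                target.id = delta.id
            target.arguments += delta.arguments
        return calls

    merged = merge_deltas([
        ToolCallDelta(id='call_1', name='functions.get_weather', arguments='{"city":'),
        ToolCallDelta(arguments=' "Beijing"}'),
    ])
    assert merged[0].arguments == '{"city": "Beijing"}'
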
+
+        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
+            if chunk:
+                # ignore sse comments
+                if chunk.startswith(':'):
+                    continue
+                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+                chunk_json = None
+                try:
+                    chunk_json = json.loads(decoded_chunk)
+                # stream ended
+                except json.JSONDecodeError as e:
+                    yield create_final_llm_result_chunk(
+                        index=chunk_index + 1,
+                        message=AssistantPromptMessage(content=""),
+                        finish_reason="Non-JSON encountered."
+                    )
+                    break
+                if not chunk_json or len(chunk_json['choices']) == 0:
+                    continue
+
+                choice = chunk_json['choices'][0]
+                finish_reason = chunk_json['choices'][0].get('finish_reason')
+                chunk_index += 1
+
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    delta_content = delta.get('content')
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta_content,
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta_content
+                elif 'text' in choice:
+                    choice_text = choice.get('text', '')
+                    if choice_text == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
+                else:
+                    continue
+
+                # check payload indicator for completion
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+        chunk_index += 1
+
+        if tools_calls:
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(
+                        tool_calls=tools_calls,
+                        content=""
+                    ),
+                )
+            )
+
+        yield create_final_llm_result_chunk(
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason
+        )
\ No newline at end of file
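
One caveat in the handler above: it peels the SSE "data: " prefix off each
line with str.lstrip, which strips any of those *characters* rather than the
literal prefix. That is harmless here only because the JSON payload begins
with '{'. A prefix-safe standalone sketch of the same parsing step:

    # Parse one SSE line into a JSON payload, skipping comments and the
    # OpenAI-style [DONE] terminator. Standalone sketch, not Dify's code.
    import json

    def parse_sse_line(line: str) -> dict | None:
        line = line.strip()
        if not line or line.startswith(':'):  # blank line or SSE comment
            return None
        if line.startswith('data:'):
            line = line[len('data:'):].strip()  # remove the literal prefix only
        if line == '[DONE]':
            return None
        try:
            return json.loads(line)
        except json.JSONDecodeError:
            return None

    assert parse_sse_line('data: {"choices": []}') == {'choices': []}
    assert parse_sse_line(': keep-alive') is None
    assert parse_sse_line('data: [DONE]') is None
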
diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml
index 1885ee9d94..34c802c2a7 100644
--- a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml
+++ b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml
@@ -20,6 +20,7 @@ supported_model_types:
   - llm
 configurate_methods:
   - predefined-model
+  - customizable-model
 provider_credential_schema:
   credential_form_schemas:
     - variable: api_key
@@ -30,3 +31,51 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      default: '4096'
+      type: text-input
+    - variable: function_calling_type
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: no_call
+          label:
+            en_US: Not supported
+            zh_Hans: 不支持
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call
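
The schema above delivers every custom-model credential as a string ('4096',
'tool_call', ...), which is why get_customizable_model_schema earlier in this
patch wraps context_size and max_tokens in int(). A hedged sketch of that
coercion step (the function name is illustrative, not Dify's API):

    # Coerce string-typed credential form values into usable types.
    # Hypothetical helper for illustration only.
    def parse_moonshot_credentials(credentials: dict) -> dict:
        return {
            'context_size': int(credentials.get('context_size', 4096)),
            'max_tokens': int(credentials.get('max_tokens', 4096)),
            'supports_tools': credentials.get('function_calling_type') == 'tool_call',
        }

    parsed = parse_moonshot_credentials({'context_size': '8192',
                                         'function_calling_type': 'tool_call'})
    assert parsed == {'context_size': 8192, 'max_tokens': 4096, 'supports_tools': True}
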
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 8cfec0e34b..04f5024207 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -378,6 +378,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")

+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_call_id: str):
+                tool_call = next(
+                    (tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
+                )
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id='',
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name='',
+                            arguments=''
+                        )
+                    )
+                    tools_calls.append(tool_call)
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.id)
+                # update tool call
+                tool_call.id = new_tool_call.id
+                tool_call.type = new_tool_call.type
+                tool_call.function.name = new_tool_call.function.name
+                tool_call.function.arguments += new_tool_call.function.arguments
+
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
                 # ignore sse comments
@@ -405,8 +433,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if 'delta' in choice:
                     delta = choice['delta']
                     delta_content = delta.get('content')
-                    if delta_content is None or delta_content == '':
-                        continue

                     assistant_message_tool_calls = delta.get('tool_calls', None)
                     # assistant_message_function_call = delta.delta.function_call
@@ -414,6 +440,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     # extract tool calls from response
                     if assistant_message_tool_calls:
                         tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == '':
+                        continue
+
                     # function_call = self._extract_response_function_call(assistant_message_function_call)
                     # tool_calls = [function_call] if function_call else []

@@ -437,6 +468,18 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

                 # check payload indicator for completion
                 if finish_reason is not None:
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=chunk_index,
+                            message=AssistantPromptMessage(
+                                tool_calls=tools_calls,
+                            ),
+                            finish_reason=finish_reason
+                        )
+                    )
+
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
@@ -735,4 +778,4 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             function=function
         )

-        return tool_call
+        return tool_call
\ No newline at end of file
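
With both handlers patched, a consumer of the stream sees text deltas first
and the fully accumulated tool calls in a trailing chunk. A standalone sketch
of consuming such a stream, with stand-in types rather than Dify's entities:

    # Collect streamed output: text deltas concatenate, accumulated tool calls
    # arrive complete near the end of the stream. Stand-in types only.
    from dataclasses import dataclass, field

    @dataclass
    class ToolCall:
        id: str
        name: str
        arguments: str

    @dataclass
    class Chunk:
        content: str = ''
        tool_calls: list[ToolCall] = field(default_factory=list)

    def collect(chunks: list[Chunk]) -> tuple[str, list[ToolCall]]:
        text, calls = '', []
        for chunk in chunks:
            text += chunk.content           # streamed answer text
            calls.extend(chunk.tool_calls)  # arrives complete in trailing chunks
        return text, calls

    text, calls = collect([
        Chunk(content='Checking the weather'),
        Chunk(tool_calls=[ToolCall('call_1', 'get_weather', '{"city": "Beijing"}')]),
    ])
    assert text == 'Checking the weather' and calls[0].name == 'get_weather'
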