From a1f06a4fdc9ae93e66a706d3c7b783b587ac0607 Mon Sep 17 00:00:00 2001
From: Song Fuchang
Date: Fri, 16 May 2025 16:32:19 +0800
Subject: [PATCH] Feat: Support tool calling in Generate component (#7572)

### What problem does this PR solve?

Hello, our use case requires an LLM agent to invoke some tools, so I made a simple implementation here.

This PR does two things:

1. A simple plugin mechanism based on `pluginlib`:

This mechanism lives in the `plugin` directory. It will only load plugins from `plugin/embedded_plugins` for now. A sample plugin `bad_calculator.py` is placed in `plugin/embedded_plugins/llm_tools`; it accepts two numbers `a` and `b`, then gives a wrong result: `a + b + 100`. In the future it can load plugins from external locations with little code change.

Plugins are divided into different types. The only plugin type supported in this PR is `llm_tools`; plugins of this type must inherit from the `LLMToolPlugin` class in `plugin/llm_tool_plugin.py`. More plugin types can be added in the future.

2. A tool selector in the `Generate` component:

Added a tool selector for choosing one or more tools for the LLM:

![image](https://github.com/user-attachments/assets/74a21fdf-9333-4175-991b-43df6524c5dc)

With the `bad_calculator` tool enabled, the `qwen-max` model produces this result:

![image](https://github.com/user-attachments/assets/93aff9c4-8550-414a-90a2-1a15a5249d94)

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Co-authored-by: Yingfeng
---
 Dockerfile                                    |   1 +
 Dockerfile.scratch.oc9                        |   1 +
 agent/component/generate.py                   |  24 ++++
 api/apps/plugin_app.py                        |  12 ++
 api/db/services/llm_service.py                |   1 +
 api/ragflow_server.py                         |   3 +
 plugin/README.md                              |  97 ++++++++++++++
 plugin/README_zh.md                           |  98 ++++++++++++++
 plugin/__init__.py                            |   3 +
 plugin/common.py                              |   1 +
 .../llm_tools/bad_calculator.py               |  37 +++++
 plugin/llm_tool_plugin.py                     |  51 +++++++
 plugin/plugin_manager.py                      |  45 +++++++
 pyproject.toml                                |   3 +-
 rag/llm/chat_model.py                         | 126 ++++++++++--------
 uv.lock                                       |  14 ++
 web/src/components/llm-select/index.tsx       |  18 ++-
 .../components/llm-setting-items/index.tsx    |   4 +-
 web/src/components/llm-tools-select.tsx       |  51 +++++++
 web/src/hooks/llm-hooks.tsx                   |   3 +-
 web/src/hooks/plugin-hooks.tsx                |  17 +++
 web/src/interfaces/database/llm.ts            |   1 +
 web/src/interfaces/database/plugin.ts         |  13 ++
 web/src/locales/en.ts                         |  12 ++
 web/src/locales/zh.ts                         |  12 ++
 .../pages/flow/form/generate-form/index.tsx   |  17 ++-
 web/src/services/plugin-service.ts            |  18 +++
 web/src/utils/api.ts                          |   3 +
 28 files changed, 625 insertions(+), 61 deletions(-)
 create mode 100644 api/apps/plugin_app.py
 create mode 100644 plugin/README.md
 create mode 100644 plugin/README_zh.md
 create mode 100644 plugin/__init__.py
 create mode 100644 plugin/common.py
 create mode 100644 plugin/embedded_plugins/llm_tools/bad_calculator.py
 create mode 100644 plugin/llm_tool_plugin.py
 create mode 100644 plugin/plugin_manager.py
 create mode 100644 web/src/components/llm-tools-select.tsx
 create mode 100644 web/src/hooks/plugin-hooks.tsx
 create mode 100644 web/src/interfaces/database/plugin.ts
 create mode 100644 web/src/services/plugin-service.ts

diff --git a/Dockerfile b/Dockerfile
index d01049bb6..47c533b9d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -199,6 +199,7 @@ COPY graphrag graphrag
 COPY agentic_reasoning agentic_reasoning
 COPY pyproject.toml uv.lock ./
 COPY mcp mcp
+COPY plugin plugin
 COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
 COPY docker/entrypoint.sh ./

diff --git a/Dockerfile.scratch.oc9 b/Dockerfile.scratch.oc9
index 50a93a6a0..64424735e 100644
--- a/Dockerfile.scratch.oc9
+++ b/Dockerfile.scratch.oc9
@@ -33,6 +33,7 @@ ADD ./rag ./rag
 ADD ./requirements.txt ./requirements.txt
 ADD ./agent ./agent
 ADD ./graphrag ./graphrag
+ADD ./plugin ./plugin

 RUN dnf install -y openmpi openmpi-devel python3-openmpi
 ENV C_INCLUDE_PATH /usr/include/openmpi-x86_64:$C_INCLUDE_PATH

diff --git a/agent/component/generate.py b/agent/component/generate.py
index ed0eeeeee..f0cdb1f15 100644
--- a/agent/component/generate.py
+++ b/agent/component/generate.py
@@ -16,15 +16,29 @@
 import json
 import re
 from functools import partial
+from typing import Any
 import pandas as pd
 from api.db import LLMType
 from api.db.services.conversation_service import structure_answer
 from api.db.services.llm_service import LLMBundle
 from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase
+from plugin import GlobalPluginManager
+from plugin.llm_tool_plugin import llm_tool_metadata_to_openai_tool
+from rag.llm.chat_model import ToolCallSession
 from rag.prompts import message_fit_in


+class LLMToolPluginCallSession(ToolCallSession):
+    def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
+        tool = GlobalPluginManager.get_llm_tool_by_name(name)
+
+        if tool is None:
+            raise ValueError(f"LLM tool {name} does not exist")
+
+        return tool().invoke(**arguments)
+
+
 class GenerateParam(ComponentParamBase):
     """
     Define the Generate component parameters.
@@ -41,6 +55,7 @@
         self.frequency_penalty = 0
         self.cite = True
         self.parameters = []
+        self.llm_enabled_tools = []

     def check(self):
         self.check_decimal_float(self.temperature, "[Generate] Temperature")
@@ -133,6 +148,15 @@ class Generate(ComponentBase):

     def _run(self, history, **kwargs):
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
+
+        if len(self._param.llm_enabled_tools) > 0:
+            tools = GlobalPluginManager.get_llm_tools_by_names(self._param.llm_enabled_tools)
+
+            chat_mdl.bind_tools(
+                LLMToolPluginCallSession(),
+                [llm_tool_metadata_to_openai_tool(t.get_metadata()) for t in tools]
+            )
+
         prompt = self._param.prompt

         retrieval_res = []

diff --git a/api/apps/plugin_app.py b/api/apps/plugin_app.py
new file mode 100644
index 000000000..dcd209daa
--- /dev/null
+++ b/api/apps/plugin_app.py
@@ -0,0 +1,12 @@
+from flask import Response
+from flask_login import login_required
+from api.utils.api_utils import get_json_result
+from plugin import GlobalPluginManager
+
+@manager.route('/llm_tools', methods=['GET'])  # noqa: F821
+@login_required
+def llm_tools() -> Response:
+    tools = GlobalPluginManager.get_llm_tools()
+    tools_metadata = [t.get_metadata() for t in tools]
+
+    return get_json_result(data=tools_metadata)

diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 8c20812db..02e66944e 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -226,6 +226,7 @@ class LLMBundle:
     def bind_tools(self, toolcall_session, tools):
         if not self.is_tools:
+            logging.warning(f"Model {self.llm_name} does not support tool calls, but you have assigned one or more tools to it!")
             return
         self.mdl.bind_tools(toolcall_session, tools)

diff --git a/api/ragflow_server.py b/api/ragflow_server.py
index 1b6775c06..024492cec 100644
--- a/api/ragflow_server.py
+++ b/api/ragflow_server.py
@@ -19,6 +19,7 @@ # beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code
 from api.utils.log_utils import initRootLogger
+from plugin import GlobalPluginManager

 initRootLogger("ragflow_server")

 import logging
@@ -119,6 +120,8 @@ if __name__ == '__main__':
     RuntimeConfig.init_env()
     RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)

+    GlobalPluginManager.load_plugins()
+
     signal.signal(signal.SIGINT, signal_handler)
     signal.signal(signal.SIGTERM, signal_handler)

diff --git a/plugin/README.md b/plugin/README.md
new file mode 100644
index 000000000..1de01e2c4
--- /dev/null
+++ b/plugin/README.md
@@ -0,0 +1,97 @@
+# Plugins
+
+This directory contains the plugin mechanism for RAGFlow.
+
+RAGFlow loads plugins from the `embedded_plugins` subdirectory recursively.
+
+## Supported plugin types
+
+Currently, the only supported plugin type is `llm_tools`.
+
+- `llm_tools`: A tool for the LLM to call.
+
+## How to add a plugin
+
+Adding an LLM tool plugin is simple: create a plugin file, put a class that inherits from `LLMToolPlugin` in it, then implement its `get_metadata` and `invoke` methods.
+
+- `get_metadata` method: This method returns an `LLMToolMetadata` object, which contains the description of the tool.
+The description is provided to the LLM, and to the RAGFlow web frontend for display.
+
+- `invoke` method: This method accepts parameters generated by the LLM and returns a `str` containing the tool execution result.
+All the execution logic of the tool should go into this method.
+
+When you start RAGFlow, you can see in the log that your plugin was loaded:
+
+```
+2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
+2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
+```
+
+Or the log may contain errors that you need to fix in your plugin.
+
+### Demo
+
+We will demonstrate how to add a plugin with a calculator tool that gives wrong answers.
+
+First, create a plugin file `bad_calculator.py` under the `embedded_plugins/llm_tools` directory.
+
+Then, we create a `BadCalculatorPlugin` class, extending the `LLMToolPlugin` base class:
+
+```python
+class BadCalculatorPlugin(LLMToolPlugin):
+    _version_ = "1.0.0"
+```
+
+The `_version_` field is required; it specifies the version of the plugin.
+
+Our calculator takes two numbers `a` and `b` as inputs, so we add an `invoke` method to our `BadCalculatorPlugin` class:
+
+```python
+def invoke(self, a: int, b: int) -> str:
+    return str(a + b + 100)
+```
+
+The `invoke` method will be called by the LLM. It can have many parameters, but the return type must be a `str`.
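+
+To sanity-check the method outside of RAGFlow, you can instantiate the plugin and call `invoke` directly (a minimal sketch, assuming the two snippets above are combined into one class):
+
+```python
+calculator = BadCalculatorPlugin()
+print(calculator.invoke(a=1, b=2))  # prints "103" -- deliberately wrong
+```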
+
+Finally, we have to add a `get_metadata` method to tell the LLM how to use our `bad_calculator`:
+
+```python
+@classmethod
+def get_metadata(cls) -> LLMToolMetadata:
+    return {
+        # Name of this tool, provided to the LLM
+        "name": "bad_calculator",
+        # Display name of this tool, provided to the RAGFlow frontend
+        "displayName": "$t:bad_calculator.name",
+        # Description of the usage of this tool, provided to the LLM
+        "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
+        # Description of this tool, provided to the RAGFlow frontend
+        "displayDescription": "$t:bad_calculator.description",
+        # Parameters of this tool
+        "parameters": {
+            # The first parameter - a
+            "a": {
+                # Parameter type; options are: number, string, or whatever the LLM can recognise
+                "type": "number",
+                # Description of this parameter, provided to the LLM
+                "description": "The first number",
+                # Description of this parameter, provided to the RAGFlow frontend
+                "displayDescription": "$t:bad_calculator.params.a",
+                # Whether this parameter is required
+                "required": True
+            },
+            # The second parameter - b
+            "b": {
+                "type": "number",
+                "description": "The second number",
+                "displayDescription": "$t:bad_calculator.params.b",
+                "required": True
+            }
+        }
+    }
+```
+
+The `get_metadata` method is a `classmethod`. It provides the description of this tool to the LLM.
+
+Fields starting with `display` can use a special notation, `$t:xxx`, which uses the i18n mechanism of the RAGFlow frontend to fetch text from the `llmTools` category. If you don't use this notation, the frontend will display the text as-is.
+
+Now our tool is ready. You can select it in the `Generate` component and try it out.

diff --git a/plugin/README_zh.md b/plugin/README_zh.md
new file mode 100644
index 000000000..17b3dd703
--- /dev/null
+++ b/plugin/README_zh.md
@@ -0,0 +1,98 @@
+# 插件
+
+这个文件夹包含了RAGFlow的插件机制。
+
+RAGFlow将会从`embedded_plugins`子文件夹中递归加载所有的插件。
+
+## 支持的插件类型
+
+目前,唯一支持的插件类型是`llm_tools`。
+
+- `llm_tools`:用于供LLM进行调用的工具。
+
+## 如何添加一个插件
+
+添加一个LLM工具插件是很简单的:创建一个插件文件,向其中放一个继承自`LLMToolPlugin`的类,再实现它的`get_metadata`和`invoke`方法即可。
+
+- `get_metadata`方法:这个方法返回一个`LLMToolMetadata`对象,其中包含了对这个工具的描述。
+这些描述信息将被提供给LLM进行调用,和RAGFlow的Web前端用作展示。
+
+- `invoke`方法:这个方法接受LLM生成的参数,并且返回一个`str`对象,其中包含了这个工具的执行结果。
+这个工具的所有执行逻辑都应当放到这个方法里。
+
+当你启动RAGFlow时,你会在日志中看见你的插件被加载了:
+
+```
+2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
+2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
+```
+
+也可能会报错,这时就需要根据报错对你的插件进行修复。
+
+### 示例
+
+我们将会添加一个会给出错误答案的计算器工具,来演示添加插件的过程。
+
+首先,在`embedded_plugins/llm_tools`文件夹下创建一个插件文件`bad_calculator.py`。
+
+接下来,我们创建一个`BadCalculatorPlugin`类,继承基类`LLMToolPlugin`:
+
+```python
+class BadCalculatorPlugin(LLMToolPlugin):
+    _version_ = "1.0.0"
+```
+
+`_version_`字段是必填的,用于指定这个插件的版本号。
+
+我们的计算器拥有两个输入字段`a`和`b`,所以我们添加如下的`invoke`方法到`BadCalculatorPlugin`类中:
+
+```python
+def invoke(self, a: int, b: int) -> str:
+    return str(a + b + 100)
+```
+
+`invoke`方法将会被LLM所调用。这个方法可以有许多参数,但它必须返回一个`str`。
+
+最后,我们需要添加一个`get_metadata`方法,来告诉LLM怎样使用我们的`bad_calculator`工具:
+
+```python
+@classmethod
+def get_metadata(cls) -> LLMToolMetadata:
+    return {
+        # 这个工具的名称,会提供给LLM
+        "name": "bad_calculator",
+        # 这个工具的展示名称,会提供给RAGFlow的Web前端
+        "displayName": "$t:bad_calculator.name",
+        # 这个工具的用法描述,会提供给LLM
+        "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
+        # 这个工具的描述,会提供给RAGFlow的Web前端
+        "displayDescription": "$t:bad_calculator.description",
+        # 这个工具的参数
+        "parameters": {
+            # 第一个参数 - a
+            "a": {
+                # 参数类型,选项为:number, string, 或者LLM可以识别的任何类型
+                "type": "number",
+                # 这个参数的描述,会提供给LLM
+                "description": "The first number",
+                # 这个参数的描述,会提供给RAGFlow的Web前端
+                "displayDescription": "$t:bad_calculator.params.a",
+                # 这个参数是否是必填的
+                "required": True
+            },
+            # 第二个参数 - b
+            "b": {
+                "type": "number",
+                "description": "The second number",
+                "displayDescription": "$t:bad_calculator.params.b",
+                "required": True
+            }
+        }
+    }
+```
+
+`get_metadata`方法是一个`classmethod`。它会把这个工具的描述提供给LLM。
+
+以`display`开头的字段可以使用一种特殊写法`$t:xxx`,这种写法将使用RAGFlow的国际化机制,从`llmTools`这个分类中获取文字。如果你不使用这种写法,那么前端将会显示此处的原始内容。
+
+现在,我们的工具已经做好了,你可以在`生成回答`组件中选择这个工具来尝试一下。
+

diff --git a/plugin/__init__.py b/plugin/__init__.py
new file mode 100644
index 000000000..379f2f761
--- /dev/null
+++ b/plugin/__init__.py
@@ -0,0 +1,3 @@
+from .plugin_manager import PluginManager
+
+GlobalPluginManager = PluginManager()

diff --git a/plugin/common.py b/plugin/common.py
new file mode 100644
index 000000000..7e85e0a13
--- /dev/null
+++ b/plugin/common.py
@@ -0,0 +1 @@
+PLUGIN_TYPE_LLM_TOOLS = "llm_tools"
\ No newline at end of file

diff --git a/plugin/embedded_plugins/llm_tools/bad_calculator.py b/plugin/embedded_plugins/llm_tools/bad_calculator.py
new file mode 100644
index 000000000..537875f0b
--- /dev/null
+++ b/plugin/embedded_plugins/llm_tools/bad_calculator.py
@@ -0,0 +1,37 @@
+import logging
+from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin
+
+
+class BadCalculatorPlugin(LLMToolPlugin):
+    """
+    A sample LLM tool plugin that adds two numbers, plus an extra 100.
+    It is only present for demo purposes. Do not use it in production.
+    """
+    _version_ = "1.0.0"
+
+    @classmethod
+    def get_metadata(cls) -> LLMToolMetadata:
+        return {
+            "name": "bad_calculator",
+            "displayName": "$t:bad_calculator.name",
+            "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
+            "displayDescription": "$t:bad_calculator.description",
+            "parameters": {
+                "a": {
+                    "type": "number",
+                    "description": "The first number",
+                    "displayDescription": "$t:bad_calculator.params.a",
+                    "required": True
+                },
+                "b": {
+                    "type": "number",
+                    "description": "The second number",
+                    "displayDescription": "$t:bad_calculator.params.b",
+                    "required": True
+                }
+            }
+        }
+
+    def invoke(self, a: int, b: int) -> str:
+        logging.info(f"Bad calculator tool was called with arguments {a} and {b}")
+        return str(a + b + 100)

diff --git a/plugin/llm_tool_plugin.py b/plugin/llm_tool_plugin.py
new file mode 100644
index 000000000..b0dc4c8e8
--- /dev/null
+++ b/plugin/llm_tool_plugin.py
@@ -0,0 +1,51 @@
+from typing import Any, TypedDict
+import pluginlib
+
+from .common import PLUGIN_TYPE_LLM_TOOLS
+
+
+class LLMToolParameter(TypedDict):
+    type: str
+    description: str
+    displayDescription: str
+    required: bool
+
+
+class LLMToolMetadata(TypedDict):
+    name: str
+    displayName: str
+    description: str
+    displayDescription: str
+    parameters: dict[str, LLMToolParameter]
+
+
+@pluginlib.Parent(PLUGIN_TYPE_LLM_TOOLS)
+class LLMToolPlugin:
+    @classmethod
+    @pluginlib.abstractmethod
+    def get_metadata(cls) -> LLMToolMetadata:
+        pass
+
+    def invoke(self, **kwargs) -> str:
+        raise NotImplementedError
+
+
+def llm_tool_metadata_to_openai_tool(llm_tool_metadata: LLMToolMetadata) -> dict[str, Any]:
+    return {
+        "type": "function",
+        "function": {
+            "name": llm_tool_metadata["name"],
+            "description": llm_tool_metadata["description"],
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    k: {
+                        "type": p["type"],
+                        "description": p["description"]
+                    }
+                    for k, p in llm_tool_metadata["parameters"].items()
+                },
+                "required": [k for k, p in llm_tool_metadata["parameters"].items() if p["required"]]
+            }
+        }
+    }

diff --git a/plugin/plugin_manager.py b/plugin/plugin_manager.py
new file mode 100644
index 000000000..1f1b81591
--- /dev/null
+++ b/plugin/plugin_manager.py
@@ -0,0 +1,45 @@
+import logging
+import os
+from pathlib import Path
+import pluginlib
+
+from .common import PLUGIN_TYPE_LLM_TOOLS
+
+from .llm_tool_plugin import LLMToolPlugin
+
+
+class PluginManager:
+    _llm_tool_plugins: dict[str, LLMToolPlugin]
+
+    def __init__(self) -> None:
+        self._llm_tool_plugins = {}
+
+    def load_plugins(self) -> None:
+        loader = pluginlib.PluginLoader(
+            paths=[str(Path(os.path.dirname(__file__), "embedded_plugins"))]
+        )
+
+        for type, plugins in loader.plugins.items():
+            for name, plugin in plugins.items():
+                logging.info(f"Loaded {type} plugin {name} version {plugin.version}")
+
+                if type == PLUGIN_TYPE_LLM_TOOLS:
+                    metadata = plugin.get_metadata()
+                    self._llm_tool_plugins[metadata["name"]] = plugin
+
+    def get_llm_tools(self) -> list[LLMToolPlugin]:
+        return list(self._llm_tool_plugins.values())
+
+    def get_llm_tool_by_name(self, name: str) -> LLMToolPlugin | None:
+        return self._llm_tool_plugins.get(name)
+
+    def get_llm_tools_by_names(self, tool_names: list[str]) -> list[LLMToolPlugin]:
+        results = []
+
+        for name in tool_names:
+            plugin = self._llm_tool_plugins.get(name)
+
+            if plugin is not None:
+                results.append(plugin)
+
+        return results

diff --git a/pyproject.toml b/pyproject.toml
index c0c827277..0b063249d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -125,7 +125,8 @@ dependencies = [
     "langfuse>=2.60.0",
     "debugpy>=1.8.13",
     "mcp>=1.6.0",
-    "opensearch-py==2.7.1"
+    "opensearch-py==2.7.1",
+    "pluginlib==0.9.4",
 ]

 [project.optional-dependencies]

diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 3d92f378f..c9c8f8884 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -21,6 +21,7 @@ import random
 import re
 import time
 from abc import ABC
+from typing import Any, Protocol

 import openai
 import requests
@@ -51,6 +52,10 @@
 LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
 LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."


+class ToolCallSession(Protocol):
+    def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
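+# Any object with a matching `tool_call` method satisfies this Protocol
+# structurally -- for example, the LLMToolPluginCallSession added in
+# agent/component/generate.py, which looks the tool up in GlobalPluginManager
+# and returns tool().invoke(**arguments).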
+
+
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
         timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
@@ -251,10 +256,8 @@ class Base(ABC):

                     if index not in final_tool_calls:
                         final_tool_calls[index] = tool_call
-
-                    final_tool_calls[index].function.arguments += tool_call.function.arguments
-                    if resp.choices[0].finish_reason != "stop":
-                        continue
+                    else:
+                        final_tool_calls[index].function.arguments += tool_call.function.arguments
                 else:
                     if not resp.choices:
                         continue
@@ -276,58 +279,57 @@
                     else:
                         total_tokens += tol

-                finish_reason = resp.choices[0].finish_reason
-                if finish_reason == "tool_calls" and final_tool_calls:
-                    for tool_call in final_tool_calls.values():
-                        name = tool_call.function.name
-                        try:
-                            if name == "get_current_weather":
-                                args = json.loads('{"location":"Shanghai"}')
-                            else:
-                                args = json.loads(tool_call.function.arguments)
-                        except Exception:
-                            continue
-                        # args = json.loads(tool_call.function.arguments)
-                        tool_response = self.toolcall_session.tool_call(name, args)
-                        history.append(
-                            {
-                                "role": "assistant",
-                                "refusal": "",
-                                "content": "",
-                                "audio": "",
-                                "function_call": "",
-                                "tool_calls": [
-                                    {
-                                        "index": tool_call.index,
-                                        "id": tool_call.id,
-                                        "function": tool_call.function,
-                                        "type": "function",
-                                    },
-                                ],
-                            }
-                        )
-                        # if tool_response.choices[0].finish_reason == "length":
-                        #     if is_chinese(ans):
-                        #         ans += LENGTH_NOTIFICATION_CN
-                        #     else:
-                        #         ans += LENGTH_NOTIFICATION_EN
-                        #     return ans, total_tokens + self.total_token_count(tool_response)
-                        history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
-                    final_tool_calls = {}
-                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-                    continue
-                if finish_reason == "length":
-                    if is_chinese(ans):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                    return ans, total_tokens + self.total_token_count(resp)
-                if finish_reason == "stop":
-                    finish_completion = True
-                    yield ans
-                    break
-                yield ans
+                    finish_reason = resp.choices[0].finish_reason
+                    if finish_reason == "tool_calls" and final_tool_calls:
+                        for tool_call in final_tool_calls.values():
+                            name = tool_call.function.name
+                            try:
+                                args = json.loads(tool_call.function.arguments)
+                            except Exception as e:
+                                logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+                                yield ans + "\n**ERROR**: " + str(e)
+                                finish_completion = True
+                                break
+
+                            tool_response = self.toolcall_session.tool_call(name, args)
+                            history.append(
+                                {
+                                    "role": "assistant",
+                                    "tool_calls": [
+                                        {
+                                            "index": tool_call.index,
+                                            "id": tool_call.id,
+                                            "function": {
+                                                "name": tool_call.function.name,
+                                                "arguments": tool_call.function.arguments,
+                                            },
+                                            "type": "function",
+                                        },
+                                    ],
+                                }
+                            )
+                            # if tool_response.choices[0].finish_reason == "length":
+                            #     if is_chinese(ans):
+                            #         ans += LENGTH_NOTIFICATION_CN
+                            #     else:
+                            #         ans += LENGTH_NOTIFICATION_EN
+                            #     return ans, total_tokens + self.total_token_count(tool_response)
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        final_tool_calls = {}
+                        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                        continue
+                    if finish_reason == "length":
+                        if is_chinese(ans):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                        return ans, total_tokens
+                    if finish_reason == "stop":
+                        finish_completion = True
+                        yield ans
+                        break
+                    yield ans
+                    continue
             except openai.APIError as e:
                 yield ans + "\n**ERROR**: " + str(e)

@@ -854,6 +856,14 @@ class ZhipuChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0

+    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+
+        return super().chat_with_tools(system, history, gen_conf)
+
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
@@ -886,6 +896,14 @@
             yield tk_count

+    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+
+        return super().chat_streamly_with_tools(system, history, gen_conf)
+

 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):

diff --git a/uv.lock b/uv.lock
index 426ae7863..5e568da3f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3952,6 +3952,18 @@ wheels = [
 { url = "https://mirrors.aliyun.com/pypi/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669" },
 ]

+[[package]]
+name = "pluginlib"
+version = "0.9.4"
+source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
+dependencies = [
+    { name = "setuptools" },
+]
+sdist = { url = "https://mirrors.aliyun.com/pypi/packages/58/38/ca974ba2d8ccc7954d8ccb0394cce184ac6269bd1fbfe06f70a0da3c8946/pluginlib-0.9.4.tar.gz", hash = "sha256:88727037138f759a3952f6391ae3751536f04ad8be6023607620ea49695a3a83" }
+wheels = [
+    { url = "https://mirrors.aliyun.com/pypi/packages/b0/b5/c869b3d2ed1613afeb02c635be11f5d35fa5b2b665f4d059cfe5b8e82941/pluginlib-0.9.4-py2.py3-none-any.whl", hash = "sha256:d4cfb7d74a6d2454e256b6512fbc4bc2dd7620cb7764feb67331ef56ce4b33f2" },
+]
+
 [[package]]
 name = "polars-lts-cpu"
 version = "1.9.0"
@@ -4872,6 +4884,7 @@ dependencies = [
     { name = "pdfplumber" },
     { name = "peewee" },
     { name = "pillow" },
+    { name = "pluginlib" },
     { name = "protobuf" },
     { name = "psycopg2-binary" },
     { name = "pyclipper" },
@@ -5009,6 +5022,7 @@ requires-dist = [
     { name = "pdfplumber", specifier = "==0.10.4" },
     { name = "peewee", specifier = "==3.17.1" },
     { name = "pillow", specifier = "==10.4.0" },
+    { name = "pluginlib", specifier = "==0.9.4" },
     { name = "protobuf", specifier = "==5.27.2" },
     { name = "psycopg2-binary", specifier = "==2.9.9" },
     { name = "pyclipper", specifier = "==1.3.0.post5" },

diff --git a/web/src/components/llm-select/index.tsx b/web/src/components/llm-select/index.tsx
index 03f5ad755..fc31f3a6c 100644
--- a/web/src/components/llm-select/index.tsx
+++ b/web/src/components/llm-select/index.tsx
@@ -11,19 +11,31 @@ import { Select, SelectTrigger, SelectValue } from '../ui/select';

 interface IProps {
   id?: string;
   value?: string;
-  onChange?: (value: string) => void;
+  onInitialValue?: (value: string, option: any) => void;
+  onChange?: (value: string, option: any) => void;
   disabled?: boolean;
 }

-const LLMSelect = ({ id, value, onChange, disabled }: IProps) => {
+const LLMSelect = ({ id, value, onInitialValue, onChange, disabled }: IProps) => {
   const modelOptions = useComposeLlmOptionsByModelTypes([
     LlmModelType.Chat,
     LlmModelType.Image2text,
   ]);

+  if (onInitialValue && value) {
+    for (const modelOption of modelOptions) {
+      for (const option of modelOption.options) {
+        if (option.value === value) {
+          onInitialValue(value, option);
+          break;
+        }
+      }
+    }
+  }
+
   const content = (
-
diff --git a/web/src/components/llm-setting-items/index.tsx b/web/src/components/llm-setting-items/index.tsx
index e91e467b9..81c4adf08 100644
--- a/web/src/components/llm-setting-items/index.tsx
+++ b/web/src/components/llm-setting-items/index.tsx
@@ -16,9 +16,10 @@ interface IProps {
   prefix?: string;
   formItemLayout?: any;
   handleParametersChange?(value: ModelVariableType): void;
+  onChange?(value: string, option: any): void;
 }

-const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
+const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => {
   const form = Form.useFormInstance();
   const { t } = useTranslate('chat');
   const parameterOptions = Object.values(ModelVariableType).map((x) => ({
@@ -58,6 +59,7 @@
             options={modelOptions}
             showSearch
             popupMatchSelectWidth={false}
+            onChange={onChange}
           />
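The selector component added next consumes the metadata returned by the new `/plugin/llm_tools` endpoint (see `api/apps/plugin_app.py` above). For the sample plugin, the `data` field of that response looks like this (a sketch assembled from `bad_calculator.py`'s `get_metadata`, shown as the Python value the endpoint serializes):

```python
[
    {
        "name": "bad_calculator",
        "displayName": "$t:bad_calculator.name",
        "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
        "displayDescription": "$t:bad_calculator.description",
        "parameters": {
            "a": {"type": "number", "description": "The first number",
                  "displayDescription": "$t:bad_calculator.params.a", "required": True},
            "b": {"type": "number", "description": "The second number",
                  "displayDescription": "$t:bad_calculator.params.b", "required": True},
        },
    }
]
```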
diff --git a/web/src/components/llm-tools-select.tsx b/web/src/components/llm-tools-select.tsx
new file mode 100644
index 000000000..241dfa374
--- /dev/null
+++ b/web/src/components/llm-tools-select.tsx
@@ -0,0 +1,51 @@
+import { useTranslate } from '@/hooks/common-hooks';
+import { useLlmToolsList } from '@/hooks/plugin-hooks';
+import { Select, Space } from 'antd';
+
+interface IProps {
+  value?: string;
+  onChange?: (value: string) => void;
+  disabled?: boolean;
+}
+
+const LLMToolsSelect = ({ value, onChange, disabled }: IProps) => {
+  const { t } = useTranslate("llmTools");
+  const tools = useLlmToolsList();
+
+  function wrapTranslation(text: string): string {
+    if (!text) {
+      return text;
+    }
+
+    if (text.startsWith("$t:")) {
+      return t(text.substring(3));
+    }
+
+    return text;
+  }
+
+  const toolOptions = tools.map(t => ({
+    label: wrapTranslation(t.displayName),
+    description: wrapTranslation(t.displayDescription),
+    value: t.name,
+    title: wrapTranslation(t.displayDescription),
+  }));
+
+  return (
+
+  );
+};
+
+export default LLMToolsSelect;

diff --git a/web/src/hooks/llm-hooks.tsx b/web/src/hooks/llm-hooks.tsx
index 7583cbd42..6dedd09bd 100644
--- a/web/src/hooks/llm-hooks.tsx
+++ b/web/src/hooks/llm-hooks.tsx
@@ -71,6 +71,7 @@ function buildLlmOptionsWithIcon(x: IThirdOAIModel) {
     ),
     value: `${x.llm_name}@${x.fid}`,
     disabled: !x.available,
+    is_tools: x.is_tools,
   };
 }

@@ -142,7 +143,7 @@ export const useComposeLlmOptionsByModelTypes = (
   return modelTypes.reduce<
     (DefaultOptionType & {
-      options: { label: JSX.Element; value: string; disabled: boolean }[];
+      options: { label: JSX.Element; value: string; disabled: boolean; is_tools: boolean }[];
     })[]
   >((pre, cur) => {
     const options = allOptions[cur];

diff --git a/web/src/hooks/plugin-hooks.tsx b/web/src/hooks/plugin-hooks.tsx
new file mode 100644
index 000000000..9812b7a93
--- /dev/null
+++ b/web/src/hooks/plugin-hooks.tsx
@@ -0,0 +1,17 @@
+import { ILLMTools } from '@/interfaces/database/plugin';
+import pluginService from '@/services/plugin-service';
+import { useQuery } from '@tanstack/react-query';
+
+export const useLlmToolsList = (): ILLMTools => {
+  const { data } = useQuery({
+    queryKey: ['llmTools'],
+    initialData: [],
+    queryFn: async () => {
+      const { data } = await pluginService.getLlmTools();
+
+      return data?.data ?? [];
+    },
+  });
+
+  return data;
+};

diff --git a/web/src/interfaces/database/llm.ts b/web/src/interfaces/database/llm.ts
index 802551d0f..2608e5ab1 100644
--- a/web/src/interfaces/database/llm.ts
+++ b/web/src/interfaces/database/llm.ts
@@ -13,6 +13,7 @@ export interface IThirdOAIModel {
   update_time: number;
   tenant_id?: string;
   tenant_name?: string;
+  is_tools: boolean;
 }

 export type IThirdOAIModelCollection = Record;

diff --git a/web/src/interfaces/database/plugin.ts b/web/src/interfaces/database/plugin.ts
new file mode 100644
index 000000000..0f2849438
--- /dev/null
+++ b/web/src/interfaces/database/plugin.ts
@@ -0,0 +1,13 @@
+export type ILLMTools = ILLMToolMetadata[];
+
+export interface ILLMToolMetadata {
+  name: string;
+  displayName: string;
+  displayDescription: string;
+  parameters: Map<string, ILLMToolParameter>;
+}
+
+export interface ILLMToolParameter {
+  type: string;
+  displayDescription: string;
+}

diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts
index c7c4614c1..08b086e0f 100644
--- a/web/src/locales/en.ts
+++ b/web/src/locales/en.ts
@@ -454,6 +454,8 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
     model: 'Model',
     modelTip: 'Large language chat model',
     modelMessage: 'Please select!',
+    modelEnabledTools: 'Enabled tools',
+    modelEnabledToolsTip: 'Please select one or more tools for the chat model to use. It has no effect on models that do not support tool calls.',
     freedom: 'Freedom',
     improvise: 'Improvise',
     precise: 'Precise',
@@ -1267,5 +1269,15 @@
     inputVariables: 'Input variables',
     runningHintText: 'is running...🕞',
   },
+  llmTools: {
+    bad_calculator: {
+      name: "Calculator",
+      description: "A tool to calculate the sum of two numbers (will give wrong answer)",
+      params: {
+        a: "The first number",
+        b: "The second number",
+      },
+    },
+  },
 },
};

diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts
index 01e1e80bd..67c747213 100644
--- a/web/src/locales/zh.ts
+++ b/web/src/locales/zh.ts
@@ -461,6 +461,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
     model: '模型',
     modelTip: '大语言聊天模型',
     modelMessage: '请选择',
+    modelEnabledTools: '可用的工具',
+    modelEnabledToolsTip: '请选择一个或多个可供该模型所使用的工具。仅对支持工具调用的模型生效。',
     freedom: '自由度',
     improvise: '即兴创作',
     precise: '精确',
@@ -1231,5 +1233,15 @@
     knowledge: 'knowledge',
     chat: 'chat',
   },
+  llmTools: {
+    bad_calculator: {
+      name: "计算器",
+      description: "用于计算两个数的和的工具(会给出错误答案)",
+      params: {
+        a: "第一个数",
+        b: "第二个数",
+      },
+    },
+  },
 },
};

diff --git a/web/src/pages/flow/form/generate-form/index.tsx b/web/src/pages/flow/form/generate-form/index.tsx
index 927e36063..e2e7ed5b7 100644
--- a/web/src/pages/flow/form/generate-form/index.tsx
+++ b/web/src/pages/flow/form/generate-form/index.tsx
@@ -4,10 +4,18 @@ import { PromptEditor } from '@/components/prompt-editor';
 import { useTranslate } from '@/hooks/common-hooks';
 import { Form, Switch } from 'antd';
 import { IOperatorForm } from '../../interface';
+import LLMToolsSelect from '@/components/llm-tools-select';
+import { useState } from 'react';

 const GenerateForm = ({ onValuesChange, form }: IOperatorForm) => {
   const { t } = useTranslate('flow');

+  const [isCurrentLlmSupportTools, setCurrentLlmSupportTools] = useState(false);
+
+  const onLlmSelectChanged = (_: string, option: any) => {
+    setCurrentLlmSupportTools(option.is_tools);
+  };
+
   return (
    {
      label={t('model', { keyPrefix: 'chat' })}
      tooltip={t('modelTip', { keyPrefix: 'chat' })}
    >
-
+
    {
    {/*  */}
+
+
+
(methods, request);
+
+export default pluginService;

diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts
index e83e5109d..b0c32b123 100644
--- a/web/src/utils/api.ts
+++ b/web/src/utils/api.ts
@@ -32,6 +32,9 @@ export default {
   delete_llm: `${api_host}/llm/delete_llm`,
   deleteFactory: `${api_host}/llm/delete_factory`,

+  // plugin
+  llm_tools: `${api_host}/plugin/llm_tools`,
+
   // knowledge base
   kb_list: `${api_host}/kb/list`,
   create_kb: `${api_host}/kb/create`,