Feat: Support tool calling in Generate component (#7572)

### What problem does this PR solve?

Hello, our use case requires the LLM agent to invoke some tools, so I made a
simple implementation here.

This PR does two things:

1. A simple plugin mechanism based on `pluginlib`:

This mechanism lives in the `plugin` directory. It will only load
plugins from `plugin/embedded_plugins` for now.

A sample plugin, `bad_calculator.py`, is placed in
`plugin/embedded_plugins/llm_tools`; it accepts two numbers `a` and `b`
and deliberately returns the wrong result `a + b + 100`.

In the future, it can load plugins from external locations with little
code change.

Plugins are divided into different types. The only plugin type supported
in this PR is `llm_tools`; such plugins must subclass `LLMToolPlugin`
from `plugin/llm_tool_plugin.py`.
More plugin types can be added in the future.
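
For illustration, an `llm_tools` plugin is a single class; the bundled sample looks roughly like this (condensed from `plugin/embedded_plugins/llm_tools/bad_calculator.py` in this PR):

```python
from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin


class BadCalculatorPlugin(LLMToolPlugin):
    _version_ = "1.0.0"  # required: the plugin version

    @classmethod
    def get_metadata(cls) -> LLMToolMetadata:
        # Name, descriptions and parameter schema, shown to the LLM and the frontend.
        return {
            "name": "bad_calculator",
            "displayName": "$t:bad_calculator.name",
            "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
            "displayDescription": "$t:bad_calculator.description",
            "parameters": {
                "a": {"type": "number", "description": "The first number",
                      "displayDescription": "$t:bad_calculator.params.a", "required": True},
                "b": {"type": "number", "description": "The second number",
                      "displayDescription": "$t:bad_calculator.params.b", "required": True},
            },
        }

    def invoke(self, a: int, b: int) -> str:
        # All execution logic lives here; the return type must be str.
        return str(a + b + 100)  # deliberately wrong, for demo purposes
```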

2. A tool selector in the `Generate` component:

Added a tool selector for choosing one or more tools for the LLM:


![image](https://github.com/user-attachments/assets/74a21fdf-9333-4175-991b-43df6524c5dc)

With the `bad_calculator` tool enabled, the `qwen-max` model produces the
following:


![image](https://github.com/user-attachments/assets/93aff9c4-8550-414a-90a2-1a15a5249d94)
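
Under the hood, the selected tool names are stored in the new `llm_enabled_tools` parameter of the `Generate` component, which binds the corresponding plugins to the chat model. A minimal sketch of that flow, assuming the plugins have been loaded (condensed from the diff below):

```python
from plugin import GlobalPluginManager
from plugin.llm_tool_plugin import llm_tool_metadata_to_openai_tool

GlobalPluginManager.load_plugins()

# Resolve the names chosen in the tool selector to the loaded plugin classes...
tools = GlobalPluginManager.get_llm_tools_by_names(["bad_calculator"])

# ...and convert their metadata to OpenAI-style tool definitions. Generate passes
# these to chat_mdl.bind_tools(...) together with an LLMToolPluginCallSession,
# which executes the matching plugin whenever the model emits a tool call.
openai_tools = [llm_tool_metadata_to_openai_tool(t.get_metadata()) for t in tools]
```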


### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Song Fuchang, 2025-05-16 16:32:19 +08:00, committed by GitHub (commit a1f06a4fdc, parent cb26564d50)
28 changed files with 625 additions and 61 deletions


@ -199,6 +199,7 @@ COPY graphrag graphrag
COPY agentic_reasoning agentic_reasoning
COPY pyproject.toml uv.lock ./
COPY mcp mcp
COPY plugin plugin
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
COPY docker/entrypoint.sh ./


@ -33,6 +33,7 @@ ADD ./rag ./rag
ADD ./requirements.txt ./requirements.txt
ADD ./agent ./agent
ADD ./graphrag ./graphrag
ADD ./plugin ./plugin
RUN dnf install -y openmpi openmpi-devel python3-openmpi
ENV C_INCLUDE_PATH /usr/include/openmpi-x86_64:$C_INCLUDE_PATH


@ -16,15 +16,29 @@
import json
import re
from functools import partial
from typing import Any
import pandas as pd
from api.db import LLMType
from api.db.services.conversation_service import structure_answer
from api.db.services.llm_service import LLMBundle
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
from plugin import GlobalPluginManager
from plugin.llm_tool_plugin import llm_tool_metadata_to_openai_tool
from rag.llm.chat_model import ToolCallSession
from rag.prompts import message_fit_in
class LLMToolPluginCallSession(ToolCallSession):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
tool = GlobalPluginManager.get_llm_tool_by_name(name)
if tool is None:
raise ValueError(f"LLM tool {name} does not exist")
return tool().invoke(**arguments)
class GenerateParam(ComponentParamBase):
"""
Define the Generate component parameters.
@ -41,6 +55,7 @@ class GenerateParam(ComponentParamBase):
self.frequency_penalty = 0
self.cite = True
self.parameters = []
self.llm_enabled_tools = []
def check(self):
self.check_decimal_float(self.temperature, "[Generate] Temperature")
@ -133,6 +148,15 @@ class Generate(ComponentBase):
def _run(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
if len(self._param.llm_enabled_tools) > 0:
tools = GlobalPluginManager.get_llm_tools_by_names(self._param.llm_enabled_tools)
chat_mdl.bind_tools(
LLMToolPluginCallSession(),
[llm_tool_metadata_to_openai_tool(t.get_metadata()) for t in tools]
)
prompt = self._param.prompt
retrieval_res = []

api/apps/plugin_app.py (new file)

@ -0,0 +1,12 @@
from flask import Response
from flask_login import login_required
from api.utils.api_utils import get_json_result
from plugin import GlobalPluginManager
@manager.route('/llm_tools', methods=['GET']) # noqa: F821
@login_required
def llm_tools() -> Response:
tools = GlobalPluginManager.get_llm_tools()
tools_metadata = [t.get_metadata() for t in tools]
return get_json_result(data=tools_metadata)


@ -226,6 +226,7 @@ class LLMBundle:
def bind_tools(self, toolcall_session, tools):
if not self.is_tools:
logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
return
self.mdl.bind_tools(toolcall_session, tools)


@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
from api.utils.log_utils import initRootLogger
from plugin import GlobalPluginManager
initRootLogger("ragflow_server")
import logging
@ -119,6 +120,8 @@ if __name__ == '__main__':
RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)
GlobalPluginManager.load_plugins()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

plugin/README.md (new file)

@ -0,0 +1,97 @@
# Plugins
This directory contains the plugin mechanism for RAGFlow.
RAGFlow will recursively load plugins from the `embedded_plugins` subdirectory.
## Supported plugin types
Currently, the only supported plugin type is `llm_tools`.
- `llm_tools`: A tool for the LLM to call.
## How to add a plugin
Adding an LLM tool plugin is simple: create a plugin file, put a class that inherits from `LLMToolPlugin` in it, then implement the `get_metadata` and `invoke` methods.
- `get_metadata` method: This method returns an `LLMToolMetadata` object, which contains the description of this tool.
The description will be provided to the LLM and to the RAGFlow web frontend for display.
- `invoke` method: This method accepts the parameters generated by the LLM and returns a `str` containing the tool execution result.
All of the execution logic of this tool should go into this method.
When you start RAGFlow, you can see in the log that your plugin was loaded:
```
2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
```
Otherwise, the log may contain errors that help you fix your plugin.
### Demo
We will demonstrate how to add a plugin using a calculator tool that deliberately gives wrong answers.
First, create a plugin file `bad_calculator.py` under the `embedded_plugins/llm_tools` directory.
Then, we create a `BadCalculatorPlugin` class, extending the `LLMToolPlugin` base class:
```python
class BadCalculatorPlugin(LLMToolPlugin):
_version_ = "1.0.0"
```
The `_version_` field is required; it specifies the version of the plugin.
Our calculator takes two numbers, `a` and `b`, as inputs, so we add an `invoke` method to our `BadCalculatorPlugin` class:
```python
def invoke(self, a: int, b: int) -> str:
return str(a + b + 100)
```
The `invoke` method will be called by the LLM. It can have any number of parameters, but its return type must be `str`.
Finally, we add a `get_metadata` method to tell the LLM how to use our `bad_calculator`:
```python
@classmethod
def get_metadata(cls) -> LLMToolMetadata:
return {
# Name of this tool, providing to LLM
"name": "bad_calculator",
# Display name of this tool, providing to RAGFlow frontend
"displayName": "$t:bad_calculator.name",
# Description of the usage of this tool, providing to LLM
"description": "A tool to calculate the sum of two numbers (will give wrong answer)",
# Description of this tool, providing to RAGFlow frontend
"displayDescription": "$t:bad_calculator.description",
# Parameters of this tool
"parameters": {
# The first parameter - a
"a": {
# Parameter type, options are: number, string, or whatever the LLM can recognise
"type": "number",
# Description of this parameter, providing to LLM
"description": "The first number",
# Description of this parameter, providing to RAGFlow frontend
"displayDescription": "$t:bad_calculator.params.a",
# Whether this parameter is required
"required": True
},
# The second parameter - b
"b": {
"type": "number",
"description": "The second number",
"displayDescription": "$t:bad_calculator.params.b",
"required": True
}
        }
    }
```
The `get_metadata` method is a `classmethod`. It provides the description of this tool to the LLM.
Fields starting with `display` can use a special notation, `$t:xxx`, which uses the i18n mechanism of the RAGFlow frontend to look up text in the `llmTools` category. If you don't use this notation, the frontend will display the literal text you put here.
Now our tool is ready. You can select it in the `Generate` component and try it out.

plugin/README_zh.md (new file)

@ -0,0 +1,98 @@
# 插件
这个文件夹包含了RAGFlow的插件机制。
RAGFlow将会从`embedded_plugins`子文件夹中递归加载所有的插件。
## 支持的插件类型
目前,唯一支持的插件类型是`llm_tools`。
- `llm_tools`:用于供LLM进行调用的工具。
## 如何添加一个插件
添加一个LLM工具插件是很简单的:创建一个插件文件,向其中放一个继承自`LLMToolPlugin`的类,再实现它的`get_metadata`和`invoke`方法即可。
- `get_metadata`方法:这个方法返回一个`LLMToolMetadata`对象,其中包含了对这个工具的描述。
这些描述信息将被提供给LLM进行调用,和RAGFlow的Web前端用作展示。
- `invoke`方法:这个方法接受LLM生成的参数,并且返回一个`str`对象,其中包含了这个工具的执行结果。
这个工具的所有执行逻辑都应当放到这个方法里。
当你启动RAGFlow时,你会在日志中看见你的插件被加载了:
```
2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
```
也可能会报错,这时就需要根据报错对你的插件进行修复。
### 示例
我们将会添加一个会给出错误答案的计算器工具,来演示添加插件的过程。
首先,在`embedded_plugins/llm_tools`文件夹下创建一个插件文件`bad_calculator.py`。
接下来,我们创建一个`BadCalculatorPlugin`类,继承基类`LLMToolPlugin`:
```python
class BadCalculatorPlugin(LLMToolPlugin):
_version_ = "1.0.0"
```
`_version_`字段是必填的,用于指定这个插件的版本号。
我们的计算器拥有两个输入字段`a`和`b`,所以我们添加如下的`invoke`方法到`BadCalculatorPlugin`类中:
```python
def invoke(self, a: int, b: int) -> str:
return str(a + b + 100)
```
`invoke`方法将会被LLM所调用。这个方法可以有许多参数,但它必须返回一个`str`。
最后,我们需要添加一个`get_metadata`方法来告诉LLM怎样使用我们的`bad_calculator`工具:
```python
@classmethod
def get_metadata(cls) -> LLMToolMetadata:
return {
# 这个工具的名称会提供给LLM
"name": "bad_calculator",
# 这个工具的展示名称会提供给RAGFlow的Web前端
"displayName": "$t:bad_calculator.name",
# 这个工具的用法描述会提供给LLM
"description": "A tool to calculate the sum of two numbers (will give wrong answer)",
# 这个工具的描述会提供给RAGFlow的Web前端
"displayDescription": "$t:bad_calculator.description",
# 这个工具的参数
"parameters": {
# 第一个参数 - a
"a": {
# 参数类型选项为number, string, 或者LLM可以识别的任何类型
"type": "number",
# 这个参数的描述会提供给LLM
"description": "The first number",
# 这个参数的描述会提供给RAGFlow的Web前端
"displayDescription": "$t:bad_calculator.params.a",
# 这个参数是否是必填的
"required": True
},
# 第二个参数 - b
"b": {
"type": "number",
"description": "The second number",
"displayDescription": "$t:bad_calculator.params.b",
"required": True
}
        }
    }
```
`get_metadata`方法是一个`classmethod`。它会把这个工具的描述提供给LLM。
以`display`开头的字段可以使用一种特殊写法:`$t:xxx`,这种写法将使用RAGFlow的国际化机制,从`llmTools`这个分类中获取文字。如果你不使用这种写法,那么前端将会显示此处的原始内容。
现在,我们的工具已经做好了,你可以在`生成回答`组件中选择这个工具来尝试一下。

plugin/__init__.py (new file)

@ -0,0 +1,3 @@
from .plugin_manager import PluginManager
GlobalPluginManager = PluginManager()

plugin/common.py (new file)

@ -0,0 +1 @@
PLUGIN_TYPE_LLM_TOOLS = "llm_tools"

plugin/embedded_plugins/llm_tools/bad_calculator.py (new file)

@ -0,0 +1,37 @@
import logging
from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin
class BadCalculatorPlugin(LLMToolPlugin):
"""
A sample LLM tool plugin that returns the sum of two numbers plus 100.
It is only present for demo purposes. Do not use it in production.
"""
_version_ = "1.0.0"
@classmethod
def get_metadata(cls) -> LLMToolMetadata:
return {
"name": "bad_calculator",
"displayName": "$t:bad_calculator.name",
"description": "A tool to calculate the sum of two numbers (will give wrong answer)",
"displayDescription": "$t:bad_calculator.description",
"parameters": {
"a": {
"type": "number",
"description": "The first number",
"displayDescription": "$t:bad_calculator.params.a",
"required": True
},
"b": {
"type": "number",
"description": "The second number",
"displayDescription": "$t:bad_calculator.params.b",
"required": True
}
}
}
def invoke(self, a: int, b: int) -> str:
logging.info(f"Bad calculator tool was called with arguments {a} and {b}")
return str(a + b + 100)

plugin/llm_tool_plugin.py (new file)

@ -0,0 +1,51 @@
from typing import Any, TypedDict
import pluginlib
from .common import PLUGIN_TYPE_LLM_TOOLS
class LLMToolParameter(TypedDict):
type: str
description: str
displayDescription: str
required: bool
class LLMToolMetadata(TypedDict):
name: str
displayName: str
description: str
displayDescription: str
parameters: dict[str, LLMToolParameter]
@pluginlib.Parent(PLUGIN_TYPE_LLM_TOOLS)
class LLMToolPlugin:
@classmethod
@pluginlib.abstractmethod
def get_metadata(cls) -> LLMToolMetadata:
pass
def invoke(self, **kwargs) -> str:
raise NotImplementedError
def llm_tool_metadata_to_openai_tool(llm_tool_metadata: LLMToolMetadata) -> dict[str, Any]:
return {
"type": "function",
"function": {
"name": llm_tool_metadata["name"],
"description": llm_tool_metadata["description"],
"parameters": {
"type": "object",
"properties": {
k: {
"type": p["type"],
"description": p["description"]
}
for k, p in llm_tool_metadata["parameters"].items()
},
"required": [k for k, p in llm_tool_metadata["parameters"].items() if p["required"]]
}
}
}
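
For the bundled `bad_calculator` metadata, applying this helper yields an OpenAI-style tool definition along these lines (derived from the code above, shown for illustration):

```python
openai_tool = {
    "type": "function",
    "function": {
        "name": "bad_calculator",
        "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "number", "description": "The first number"},
                "b": {"type": "number", "description": "The second number"},
            },
            "required": ["a", "b"],
        },
    },
}
```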

plugin/plugin_manager.py (new file)

@ -0,0 +1,45 @@
import logging
import os
from pathlib import Path
import pluginlib
from .common import PLUGIN_TYPE_LLM_TOOLS
from .llm_tool_plugin import LLMToolPlugin
class PluginManager:
_llm_tool_plugins: dict[str, LLMToolPlugin]
def __init__(self) -> None:
self._llm_tool_plugins = {}
def load_plugins(self) -> None:
loader = pluginlib.PluginLoader(
paths=[str(Path(os.path.dirname(__file__), "embedded_plugins"))]
)
for type, plugins in loader.plugins.items():
for name, plugin in plugins.items():
logging.info(f"Loaded {type} plugin {name} version {plugin.version}")
if type == PLUGIN_TYPE_LLM_TOOLS:
metadata = plugin.get_metadata()
self._llm_tool_plugins[metadata["name"]] = plugin
def get_llm_tools(self) -> list[LLMToolPlugin]:
return list(self._llm_tool_plugins.values())
def get_llm_tool_by_name(self, name: str) -> LLMToolPlugin | None:
return self._llm_tool_plugins.get(name)
def get_llm_tools_by_names(self, tool_names: list[str]) -> list[LLMToolPlugin]:
results = []
for name in tool_names:
plugin = self._llm_tool_plugins.get(name)
if plugin is not None:
results.append(plugin)
return results
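
A minimal usage sketch, mirroring how `ragflow_server.py` and `plugin_app.py` use the manager (the `GlobalPluginManager` singleton comes from `plugin/__init__.py`):

```python
from plugin import GlobalPluginManager

GlobalPluginManager.load_plugins()  # scans plugin/embedded_plugins recursively

calc = GlobalPluginManager.get_llm_tool_by_name("bad_calculator")
if calc is not None:
    # Plugins are stored as classes; instantiate before invoking, as
    # LLMToolPluginCallSession does in agent/component/generate.py.
    print(calc().invoke(a=1, b=2))  # -> "103" (the deliberately wrong sum)
```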


@ -125,7 +125,8 @@ dependencies = [
"langfuse>=2.60.0",
"debugpy>=1.8.13",
"mcp>=1.6.0",
"opensearch-py==2.7.1"
"opensearch-py==2.7.1",
"pluginlib==0.9.4",
]
[project.optional-dependencies]


@ -21,6 +21,7 @@ import random
import re
import time
from abc import ABC
from typing import Any, Protocol
import openai
import requests
@ -51,6 +52,10 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class Base(ABC):
def __init__(self, key, model_name, base_url):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
@ -251,10 +256,8 @@ class Base(ABC):
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
if resp.choices[0].finish_reason != "stop":
continue
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments
else:
if not resp.choices:
continue
@ -276,58 +279,57 @@ class Base(ABC):
else:
total_tokens += tol
finish_reason = resp.choices[0].finish_reason
if finish_reason == "tool_calls" and final_tool_calls:
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
if name == "get_current_weather":
args = json.loads('{"location":"Shanghai"}')
else:
args = json.loads(tool_call.function.arguments)
except Exception:
continue
# args = json.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history.append(
{
"role": "assistant",
"refusal": "",
"content": "",
"audio": "",
"function_call": "",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": tool_call.function,
"type": "function",
finish_reason = resp.choices[0].finish_reason
if finish_reason == "tool_calls" and final_tool_calls:
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json.loads(tool_call.function.arguments)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
yield ans + "\n**ERROR**: " + str(e)
finish_completion = True
break
tool_response = self.toolcall_session.tool_call(name, args)
history.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
],
}
)
# if tool_response.choices[0].finish_reason == "length":
# if is_chinese(ans):
# ans += LENGTH_NOTIFICATION_CN
# else:
# ans += LENGTH_NOTIFICATION_EN
# return ans, total_tokens + self.total_token_count(tool_response)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
final_tool_calls = {}
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
continue
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, total_tokens + self.total_token_count(resp)
if finish_reason == "stop":
finish_completion = True
yield ans
break
yield ans
"type": "function",
},
],
}
)
# if tool_response.choices[0].finish_reason == "length":
# if is_chinese(ans):
# ans += LENGTH_NOTIFICATION_CN
# else:
# ans += LENGTH_NOTIFICATION_EN
# return ans, total_tokens + self.total_token_count(tool_response)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
final_tool_calls = {}
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
continue
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, total_tokens
if finish_reason == "stop":
finish_completion = True
yield ans
break
yield ans
continue
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
@ -854,6 +856,14 @@ class ZhipuChat(Base):
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
return super().chat_with_tools(system, history, gen_conf)
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
@ -886,6 +896,14 @@ class ZhipuChat(Base):
yield tk_count
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
return super().chat_streamly_with_tools(system, history, gen_conf)
class OllamaChat(Base):
def __init__(self, key, model_name, **kwargs):

uv.lock (generated)

@ -3952,6 +3952,18 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669" },
]
[[package]]
name = "pluginlib"
version = "0.9.4"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "setuptools" },
]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/58/38/ca974ba2d8ccc7954d8ccb0394cce184ac6269bd1fbfe06f70a0da3c8946/pluginlib-0.9.4.tar.gz", hash = "sha256:88727037138f759a3952f6391ae3751536f04ad8be6023607620ea49695a3a83" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/b0/b5/c869b3d2ed1613afeb02c635be11f5d35fa5b2b665f4d059cfe5b8e82941/pluginlib-0.9.4-py2.py3-none-any.whl", hash = "sha256:d4cfb7d74a6d2454e256b6512fbc4bc2dd7620cb7764feb67331ef56ce4b33f2" },
]
[[package]]
name = "polars-lts-cpu"
version = "1.9.0"
@ -4872,6 +4884,7 @@ dependencies = [
{ name = "pdfplumber" },
{ name = "peewee" },
{ name = "pillow" },
{ name = "pluginlib" },
{ name = "protobuf" },
{ name = "psycopg2-binary" },
{ name = "pyclipper" },
@ -5009,6 +5022,7 @@ requires-dist = [
{ name = "pdfplumber", specifier = "==0.10.4" },
{ name = "peewee", specifier = "==3.17.1" },
{ name = "pillow", specifier = "==10.4.0" },
{ name = "pluginlib", specifier = "==0.9.4" },
{ name = "protobuf", specifier = "==5.27.2" },
{ name = "psycopg2-binary", specifier = "==2.9.9" },
{ name = "pyclipper", specifier = "==1.3.0.post5" },


@ -11,19 +11,31 @@ import { Select, SelectTrigger, SelectValue } from '../ui/select';
interface IProps {
id?: string;
value?: string;
onChange?: (value: string) => void;
onInitialValue?: (value: string, option: any) => void;
onChange?: (value: string, option: any) => void;
disabled?: boolean;
}
const LLMSelect = ({ id, value, onChange, disabled }: IProps) => {
const LLMSelect = ({ id, value, onInitialValue, onChange, disabled }: IProps) => {
const modelOptions = useComposeLlmOptionsByModelTypes([
LlmModelType.Chat,
LlmModelType.Image2text,
]);
if (onInitialValue && value) {
for (const modelOption of modelOptions) {
for (const option of modelOption.options) {
if (option.value === value) {
onInitialValue(value, option);
break;
}
}
}
}
const content = (
<div style={{ width: 400 }}>
<LlmSettingItems
<LlmSettingItems onChange={onChange}
formItemLayout={{ labelCol: { span: 10 }, wrapperCol: { span: 14 } }}
></LlmSettingItems>
</div>


@ -16,9 +16,10 @@ interface IProps {
prefix?: string;
formItemLayout?: any;
handleParametersChange?(value: ModelVariableType): void;
onChange?(value: string, option: any): void;
}
const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => {
const form = Form.useFormInstance();
const { t } = useTranslate('chat');
const parameterOptions = Object.values(ModelVariableType).map((x) => ({
@ -58,6 +59,7 @@ const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
options={modelOptions}
showSearch
popupMatchSelectWidth={false}
onChange={onChange}
/>
</Form.Item>
<div className="border rounded-md">


@ -0,0 +1,51 @@
import { useTranslate } from '@/hooks/common-hooks';
import { useLlmToolsList } from '@/hooks/plugin-hooks';
import { Select, Space } from 'antd';
interface IProps {
value?: string;
onChange?: (value: string) => void;
disabled?: boolean;
}
const LLMToolsSelect = ({ value, onChange, disabled }: IProps) => {
const { t } = useTranslate("llmTools");
const tools = useLlmToolsList();
function wrapTranslation(text: string): string {
if (!text) {
return text;
}
if (text.startsWith("$t:")) {
return t(text.substring(3));
}
return text;
}
const toolOptions = tools.map(t => ({
label: wrapTranslation(t.displayName),
description: wrapTranslation(t.displayDescription),
value: t.name,
title: wrapTranslation(t.displayDescription),
}));
return (
<Select
mode="multiple"
options={toolOptions}
optionRender={option => (
<Space size="large">
{option.label}
{option.data.description}
</Space>
)}
onChange={onChange}
value={value}
disabled={disabled}
></Select>
);
};
export default LLMToolsSelect;


@ -71,6 +71,7 @@ function buildLlmOptionsWithIcon(x: IThirdOAIModel) {
),
value: `${x.llm_name}@${x.fid}`,
disabled: !x.available,
is_tools: x.is_tools,
};
}
@ -142,7 +143,7 @@ export const useComposeLlmOptionsByModelTypes = (
return modelTypes.reduce<
(DefaultOptionType & {
options: { label: JSX.Element; value: string; disabled: boolean }[];
options: { label: JSX.Element; value: string; disabled: boolean; is_tools: boolean }[];
})[]
>((pre, cur) => {
const options = allOptions[cur];


@ -0,0 +1,17 @@
import { ILLMTools } from '@/interfaces/database/plugin';
import pluginService from '@/services/plugin-service';
import { useQuery } from '@tanstack/react-query';
export const useLlmToolsList = (): ILLMTools => {
const { data } = useQuery({
queryKey: ['llmTools'],
initialData: [],
queryFn: async () => {
const { data } = await pluginService.getLlmTools();
return data?.data ?? [];
},
});
return data;
};


@ -13,6 +13,7 @@ export interface IThirdOAIModel {
update_time: number;
tenant_id?: string;
tenant_name?: string;
is_tools: boolean;
}
export type IThirdOAIModelCollection = Record<string, IThirdOAIModel[]>;


@ -0,0 +1,13 @@
export type ILLMTools = ILLMToolMetadata[];
export interface ILLMToolMetadata {
name: string;
displayName: string;
displayDescription: string;
parameters: Map<string, ILLMToolParameter>;
}
export interface ILLMToolParameter {
type: string;
displayDescription: string;
}


@ -454,6 +454,8 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
model: 'Model',
modelTip: 'Large language chat model',
modelMessage: 'Please select!',
modelEnabledTools: 'Enabled tools',
modelEnabledToolsTip: 'Please select one or more tools for the chat model to use. It has no effect for models that do not support tool calls.',
freedom: 'Freedom',
improvise: 'Improvise',
precise: 'Precise',
@ -1267,5 +1269,15 @@ This delimiter is used to split the input text into several text pieces echo of
inputVariables: 'Input variables',
runningHintText: 'is running...🕞',
},
llmTools: {
bad_calculator: {
name: "Calculator",
description: "A tool to calculate the sum of two numbers (will give wrong answer)",
params: {
a: "The first number",
b: "The second number",
},
},
},
},
};


@ -461,6 +461,8 @@ General实体和关系提取提示来自 GitHub - microsoft/graphrag基于
model: '模型',
modelTip: '大语言聊天模型',
modelMessage: '请选择',
modelEnabledTools: '可用的工具',
modelEnabledToolsTip: '请选择一个或多个可供该模型所使用的工具。仅对支持工具调用的模型生效。',
freedom: '自由度',
improvise: '即兴创作',
precise: '精确',
@ -1231,5 +1233,15 @@ General实体和关系提取提示来自 GitHub - microsoft/graphrag基于
knowledge: 'knowledge',
chat: 'chat',
},
llmTools: {
bad_calculator: {
name: "计算器",
description: "用于计算两个数的和的工具(会给出错误答案)",
params: {
a: "第一个数",
b: "第二个数",
},
},
},
},
};


@ -4,10 +4,18 @@ import { PromptEditor } from '@/components/prompt-editor';
import { useTranslate } from '@/hooks/common-hooks';
import { Form, Switch } from 'antd';
import { IOperatorForm } from '../../interface';
import LLMToolsSelect from '@/components/llm-tools-select';
import { useState } from 'react';
const GenerateForm = ({ onValuesChange, form }: IOperatorForm) => {
const { t } = useTranslate('flow');
const [isCurrentLlmSupportTools, setCurrentLlmSupportTools] = useState(false);
const onLlmSelectChanged = (_: string, option: any) => {
setCurrentLlmSupportTools(option.is_tools);
};
return (
<Form
name="basic"
@ -21,7 +29,7 @@ const GenerateForm = ({ onValuesChange, form }: IOperatorForm) => {
label={t('model', { keyPrefix: 'chat' })}
tooltip={t('modelTip', { keyPrefix: 'chat' })}
>
<LLMSelect></LLMSelect>
<LLMSelect onInitialValue={onLlmSelectChanged} onChange={onLlmSelectChanged}></LLMSelect>
</Form.Item>
<Form.Item
name={['prompt']}
@ -38,6 +46,13 @@ const GenerateForm = ({ onValuesChange, form }: IOperatorForm) => {
{/* <Input.TextArea rows={8}></Input.TextArea> */}
<PromptEditor></PromptEditor>
</Form.Item>
<Form.Item
name={'llm_enabled_tools'}
label={t('modelEnabledTools', { keyPrefix: 'chat' })}
tooltip={t('modelEnabledToolsTip', { keyPrefix: 'chat' })}
>
<LLMToolsSelect disabled={!isCurrentLlmSupportTools}></LLMToolsSelect>
</Form.Item>
<Form.Item
name={['cite']}
label={t('cite')}


@ -0,0 +1,18 @@
import api from '@/utils/api';
import registerServer from '@/utils/register-server';
import request from '@/utils/request';
const {
llm_tools
} = api;
const methods = {
getLlmTools: {
url: llm_tools,
method: 'get',
},
} as const;
const pluginService = registerServer<keyof typeof methods>(methods, request);
export default pluginService;


@ -32,6 +32,9 @@ export default {
delete_llm: `${api_host}/llm/delete_llm`,
deleteFactory: `${api_host}/llm/delete_factory`,
// plugin
llm_tools: `${api_host}/plugin/llm_tools`,
// knowledge base
kb_list: `${api_host}/kb/list`,
create_kb: `${api_host}/kb/create`,