Feat: Support passing knowledge base id as variable in retrieval component (#7088)

### What problem does this PR solve?

Fix #6600

Hello, I have the same business requirement as #6600. My use case is: 

We have many departments (> 20 now and increasing), and each department
has its own knowledge base. Because the agent workflow is the same, so I
want to change the knowledge base on the fly, instead of creating agents
for every department.

It now looks like this:


![屏幕截图_20250416_212622](https://github.com/user-attachments/assets/5cb3dade-d4fb-4591-ade3-4b9c54387911)

Knowledge bases can be selected from the dropdown, and passed through
the variables in the table. All selected knowledge bases are used for
retrieval.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
This commit is contained in:
Song Fuchang 2025-04-30 15:32:14 +08:00 committed by GitHub
parent f56b651acb
commit 6e7dd54a50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 90 additions and 43 deletions

View File

@ -19,7 +19,7 @@ import json
import os
import logging
from functools import partial
from typing import Tuple, Union
from typing import Any, Tuple, Union
import pandas as pd
@ -462,6 +462,33 @@ class ComponentBase(ABC):
def set_output(self, v):
setattr(self._param, self._param.output_var_name, v)
def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
outs = []
for q in sources:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
outs.append(pd.DataFrame([{"content": txt}]))
continue
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
elif q.get("value"):
outs.append(pd.DataFrame([{"content": q["value"]}]))
return outs
def get_input(self):
if self._param.debug_inputs:
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
@ -475,37 +502,24 @@ class ComponentBase(ABC):
if self._param.query:
self._param.inputs = []
outs = []
for q in self._param.query:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
self._param.inputs.append({"component_id": q["component_id"],
"content": p.get("value", "")})
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
outs = self._fetch_outputs_from(self._param.query)
if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
self._param.inputs.append({"content": txt, "component_id": q["component_id"]})
outs.append(pd.DataFrame([{"content": txt}]))
continue
for out in outs:
records = out.to_dict("records")
content: str
if len(records) > 1:
content = "\n".join(
[str(d["content"]) for d in records]
)
else:
content = records[0]["content"]
self._param.inputs.append({
"component_id": records[0].get("component_id"),
"content": content
})
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
self._param.inputs.append({"component_id": q["component_id"],
"content": "\n".join(
[str(d["content"]) for d in outs[-1].to_dict('records')])})
elif q.get("value"):
self._param.inputs.append({"component_id": None, "content": q["value"]})
outs.append(pd.DataFrame([{"content": q["value"]}]))
if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df:

View File

@ -41,6 +41,7 @@ class RetrievalParam(ComponentParamBase):
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
self.tavily_api_key = ""
@ -58,7 +59,22 @@ class Retrieval(ComponentBase, ABC):
def _run(self, history, **kwargs):
query = self.get_input()
query = str(query["content"][0]) if "content" in query else ""
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
kb_ids: list[str] = self._param.kb_ids or []
kb_vars = self._fetch_outputs_from(self._param.kb_vars)
if len(kb_vars) > 0:
for kb_var in kb_vars:
if len(kb_var) == 1:
kb_ids.append(str(kb_var["content"][0]))
else:
for v in kb_var.to_dict("records"):
kb_ids.append(v["content"])
filtered_kb_ids: list[str] = [kb_id for kb_id in kb_ids if kb_id]
kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
if not kbs:
return Retrieval.be_output("")
@ -75,7 +91,7 @@ class Retrieval(ComponentBase, ABC):
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
if kbs:
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, filtered_kb_ids,
1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl,
@ -86,7 +102,7 @@ class Retrieval(ComponentBase, ABC):
if self._param.use_kg and kbs:
ck = settings.kg_retrievaler.retrieval(query,
[kbs[0].tenant_id],
self._param.kb_ids,
filtered_kb_ids,
embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:

View File

@ -10,11 +10,13 @@ import { FormControl, FormField, FormItem, FormLabel } from './ui/form';
import { MultiSelect } from './ui/multi-select';
interface KnowledgeBaseItemProps {
tooltipText?: string;
required?: boolean;
onChange?(): void;
}
const KnowledgeBaseItem = ({
tooltipText,
required = true,
onChange,
}: KnowledgeBaseItemProps) => {
@ -40,7 +42,7 @@ const KnowledgeBaseItem = ({
<Form.Item
label={t('knowledgeBases')}
name="kb_ids"
tooltip={t('knowledgeBasesTip')}
tooltip={tooltipText || t('knowledgeBasesTip')}
rules={[
{
required,

View File

@ -1255,6 +1255,9 @@ This delimiter is used to split the input text into several text pieces echo of
promptTip:
'Use the system prompt to describe the task for the LLM, specify how it should respond, and outline other miscellaneous requirements. The system prompt is often used in conjunction with keys (variables), which serve as various data inputs for the LLM. Use a forward slash `/` or the (x) button to show the keys to use.',
promptMessage: 'Prompt is required',
knowledgeBasesTip:
'Select the knowledge bases to associate with this chat assistant, or choose variables containing knowledge base IDs below.',
knowledgeBaseVars: 'Knowledge base variables',
},
},
};

View File

@ -1210,6 +1210,8 @@ General实体和关系提取提示来自 GitHub - microsoft/graphrag基于
promptMessage: '提示词是必填项',
promptTip:
'系统提示为大模型提供任务描述、规定回复方式,以及设置其他各种要求。系统提示通常与 key (变量)合用,通过变量设置大模型的输入数据。你可以通过斜杠或者 (x) 按钮显示可用的 key。',
knowledgeBasesTip: '选择关联的知识库或者在下方选择包含知识库ID的变量。',
knowledgeBaseVars: '知识库变量',
},
footer: {
profile: 'All rights reserved @ React',

View File

@ -8,7 +8,9 @@ import { useBuildComponentIdSelectOptions } from '../../hooks/use-get-begin-quer
import styles from './index.less';
interface IProps {
name?: string;
node?: RAGFlowNodeType;
title?: string;
}
enum VariableType {
@ -19,7 +21,8 @@ enum VariableType {
const getVariableName = (type: string) =>
type === VariableType.Reference ? 'component_id' : 'value';
const DynamicVariableForm = ({ node }: IProps) => {
const DynamicVariableForm = ({ name: formName, node }: IProps) => {
formName = formName || 'query';
const { t } = useTranslation();
const valueOptions = useBuildComponentIdSelectOptions(
node?.id,
@ -35,15 +38,15 @@ const DynamicVariableForm = ({ node }: IProps) => {
const handleTypeChange = useCallback(
(name: number) => () => {
setTimeout(() => {
form.setFieldValue(['query', name, 'component_id'], undefined);
form.setFieldValue(['query', name, 'value'], undefined);
form.setFieldValue([formName, name, 'component_id'], undefined);
form.setFieldValue([formName, name, 'value'], undefined);
}, 0);
},
[form],
);
return (
<Form.List name="query">
<Form.List name={formName}>
{(fields, { add, remove }) => (
<>
{fields.map(({ key, name, ...restField }) => (
@ -60,7 +63,7 @@ const DynamicVariableForm = ({ node }: IProps) => {
</Form.Item>
<Form.Item noStyle dependencies={[name, 'type']}>
{({ getFieldValue }) => {
const type = getFieldValue(['query', name, 'type']);
const type = getFieldValue([formName, name, 'type']);
return (
<Form.Item
{...restField}
@ -118,11 +121,11 @@ export function FormCollapse({
);
}
const DynamicInputVariable = ({ node }: IProps) => {
const DynamicInputVariable = ({ name, node, title }: IProps) => {
const { t } = useTranslation();
return (
<FormCollapse title={t('flow.input')}>
<DynamicVariableForm node={node}></DynamicVariableForm>
<FormCollapse title={title || t('flow.input')}>
<DynamicVariableForm name={name} node={node}></DynamicVariableForm>
</FormCollapse>
);
};

View File

@ -43,7 +43,14 @@ const RetrievalForm = ({ onValuesChange, form, node }: IOperatorForm) => {
<Rerank></Rerank>
<TavilyItem name={'tavily_api_key'}></TavilyItem>
<UseKnowledgeGraphItem filedName={'use_kg'}></UseKnowledgeGraphItem>
<KnowledgeBaseItem></KnowledgeBaseItem>
<KnowledgeBaseItem
tooltipText={t('knowledgeBasesTip')}
></KnowledgeBaseItem>
<DynamicInputVariable
name={'kb_vars'}
node={node}
title={t('knowledgeBaseVars')}
></DynamicInputVariable>
<Form.Item
name={'empty_response'}
label={t('emptyResponse', { keyPrefix: 'chat' })}