From 6e7dd54a5065c05df2e3f2f869720bc3fbab90cd Mon Sep 17 00:00:00 2001 From: Song Fuchang Date: Wed, 30 Apr 2025 15:32:14 +0800 Subject: [PATCH] Feat: Support passing knowledge base id as variable in retrieval component (#7088) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Fix #6600 Hello, I have the same business requirement as #6600. My use case is: We have many departments (> 20 now and increasing), and each department has its own knowledge base. Because the agent workflow is the same, so I want to change the knowledge base on the fly, instead of creating agents for every department. It now looks like this: ![屏幕截图_20250416_212622](https://github.com/user-attachments/assets/5cb3dade-d4fb-4591-ade3-4b9c54387911) Knowledge bases can be selected from the dropdown, and passed through the variables in the table. All selected knowledge bases are used for retrieval. ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): --- agent/component/base.py | 74 +++++++++++-------- agent/component/retrieval.py | 22 +++++- web/src/components/knowledge-base-item.tsx | 4 +- web/src/locales/en.ts | 3 + web/src/locales/zh.ts | 2 + .../components/dynamic-input-variable.tsx | 19 +++-- .../pages/flow/form/retrieval-form/index.tsx | 9 ++- 7 files changed, 90 insertions(+), 43 deletions(-) diff --git a/agent/component/base.py b/agent/component/base.py index a09e6c0a2..54ccb6c3c 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -19,7 +19,7 @@ import json import os import logging from functools import partial -from typing import Tuple, Union +from typing import Any, Tuple, Union import pandas as pd @@ -462,6 +462,33 @@ class ComponentBase(ABC): def set_output(self, v): setattr(self._param, self._param.output_var_name, v) + def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]: + outs = [] + for q in sources: + if q.get("component_id"): + if q["component_id"].split("@")[0].lower().find("begin") >= 0: + cpn_id, key = q["component_id"].split("@") + for p in self._canvas.get_component(cpn_id)["obj"]._param.query: + if p["key"] == key: + outs.append(pd.DataFrame([{"content": p.get("value", "")}])) + break + else: + assert False, f"Can't find parameter '{key}' for {cpn_id}" + continue + + if q["component_id"].lower().find("answer") == 0: + txt = [] + for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]: + txt.append(f"{r.upper()}:{c}") + txt = "\n".join(txt) + outs.append(pd.DataFrame([{"content": txt}])) + continue + + outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1]) + elif q.get("value"): + outs.append(pd.DataFrame([{"content": q["value"]}])) + return outs + def get_input(self): if self._param.debug_inputs: return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")]) @@ -475,37 +502,24 @@ class ComponentBase(ABC): if self._param.query: self._param.inputs = [] - outs = [] - for q in self._param.query: - if q.get("component_id"): - if q["component_id"].split("@")[0].lower().find("begin") >= 0: - cpn_id, key = q["component_id"].split("@") - for p in self._canvas.get_component(cpn_id)["obj"]._param.query: - if p["key"] == key: - outs.append(pd.DataFrame([{"content": p.get("value", "")}])) - self._param.inputs.append({"component_id": q["component_id"], - "content": p.get("value", "")}) - break - else: - assert False, f"Can't find parameter '{key}' for {cpn_id}" - continue + outs = self._fetch_outputs_from(self._param.query) - if q["component_id"].lower().find("answer") == 0: - txt = [] - for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]: - txt.append(f"{r.upper()}:{c}") - txt = "\n".join(txt) - self._param.inputs.append({"content": txt, "component_id": q["component_id"]}) - outs.append(pd.DataFrame([{"content": txt}])) - continue + for out in outs: + records = out.to_dict("records") + content: str + + if len(records) > 1: + content = "\n".join( + [str(d["content"]) for d in records] + ) + else: + content = records[0]["content"] + + self._param.inputs.append({ + "component_id": records[0].get("component_id"), + "content": content + }) - outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1]) - self._param.inputs.append({"component_id": q["component_id"], - "content": "\n".join( - [str(d["content"]) for d in outs[-1].to_dict('records')])}) - elif q.get("value"): - self._param.inputs.append({"component_id": None, "content": q["value"]}) - outs.append(pd.DataFrame([{"content": q["value"]}])) if outs: df = pd.concat(outs, ignore_index=True) if "content" in df: diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index 8a2568c13..5364342b2 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -41,6 +41,7 @@ class RetrievalParam(ComponentParamBase): self.top_n = 8 self.top_k = 1024 self.kb_ids = [] + self.kb_vars = [] self.rerank_id = "" self.empty_response = "" self.tavily_api_key = "" @@ -58,7 +59,22 @@ class Retrieval(ComponentBase, ABC): def _run(self, history, **kwargs): query = self.get_input() query = str(query["content"][0]) if "content" in query else "" - kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids) + + kb_ids: list[str] = self._param.kb_ids or [] + + kb_vars = self._fetch_outputs_from(self._param.kb_vars) + + if len(kb_vars) > 0: + for kb_var in kb_vars: + if len(kb_var) == 1: + kb_ids.append(str(kb_var["content"][0])) + else: + for v in kb_var.to_dict("records"): + kb_ids.append(v["content"]) + + filtered_kb_ids: list[str] = [kb_id for kb_id in kb_ids if kb_id] + + kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids) if not kbs: return Retrieval.be_output("") @@ -75,7 +91,7 @@ class Retrieval(ComponentBase, ABC): rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) if kbs: - kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids, + kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, filtered_kb_ids, 1, self._param.top_n, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, aggs=False, rerank_mdl=rerank_mdl, @@ -86,7 +102,7 @@ class Retrieval(ComponentBase, ABC): if self._param.use_kg and kbs: ck = settings.kg_retrievaler.retrieval(query, [kbs[0].tenant_id], - self._param.kb_ids, + filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: diff --git a/web/src/components/knowledge-base-item.tsx b/web/src/components/knowledge-base-item.tsx index 9e723df87..65b218c76 100644 --- a/web/src/components/knowledge-base-item.tsx +++ b/web/src/components/knowledge-base-item.tsx @@ -10,11 +10,13 @@ import { FormControl, FormField, FormItem, FormLabel } from './ui/form'; import { MultiSelect } from './ui/multi-select'; interface KnowledgeBaseItemProps { + tooltipText?: string; required?: boolean; onChange?(): void; } const KnowledgeBaseItem = ({ + tooltipText, required = true, onChange, }: KnowledgeBaseItemProps) => { @@ -40,7 +42,7 @@ const KnowledgeBaseItem = ({ type === VariableType.Reference ? 'component_id' : 'value'; -const DynamicVariableForm = ({ node }: IProps) => { +const DynamicVariableForm = ({ name: formName, node }: IProps) => { + formName = formName || 'query'; const { t } = useTranslation(); const valueOptions = useBuildComponentIdSelectOptions( node?.id, @@ -35,15 +38,15 @@ const DynamicVariableForm = ({ node }: IProps) => { const handleTypeChange = useCallback( (name: number) => () => { setTimeout(() => { - form.setFieldValue(['query', name, 'component_id'], undefined); - form.setFieldValue(['query', name, 'value'], undefined); + form.setFieldValue([formName, name, 'component_id'], undefined); + form.setFieldValue([formName, name, 'value'], undefined); }, 0); }, [form], ); return ( - + {(fields, { add, remove }) => ( <> {fields.map(({ key, name, ...restField }) => ( @@ -60,7 +63,7 @@ const DynamicVariableForm = ({ node }: IProps) => { {({ getFieldValue }) => { - const type = getFieldValue(['query', name, 'type']); + const type = getFieldValue([formName, name, 'type']); return ( { +const DynamicInputVariable = ({ name, node, title }: IProps) => { const { t } = useTranslation(); return ( - - + + ); }; diff --git a/web/src/pages/flow/form/retrieval-form/index.tsx b/web/src/pages/flow/form/retrieval-form/index.tsx index 3d7f29e64..fa3da7131 100644 --- a/web/src/pages/flow/form/retrieval-form/index.tsx +++ b/web/src/pages/flow/form/retrieval-form/index.tsx @@ -43,7 +43,14 @@ const RetrievalForm = ({ onValuesChange, form, node }: IOperatorForm) => { - + +