diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index 5364342b2..22a4b6827 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -30,10 +30,10 @@ from rag.utils.tavily_conn import Tavily class RetrievalParam(ComponentParamBase): - """ Define the Retrieval component parameters. """ + def __init__(self): super().__init__() self.similarity_threshold = 0.2 @@ -67,7 +67,10 @@ class Retrieval(ComponentBase, ABC): if len(kb_vars) > 0: for kb_var in kb_vars: if len(kb_var) == 1: - kb_ids.append(str(kb_var["content"][0])) + kb_var_value = str(kb_var["content"][0]) + + for v in kb_var_value.split(","): + kb_ids.append(v) else: for v in kb_var.to_dict("records"): kb_ids.append(v["content"]) @@ -91,20 +94,24 @@ class Retrieval(ComponentBase, ABC): rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) if kbs: - kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, filtered_kb_ids, - 1, self._param.top_n, - self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, - aggs=False, rerank_mdl=rerank_mdl, - rank_feature=label_question(query, kbs)) + kbinfos = settings.retrievaler.retrieval( + query, + embd_mdl, + kbs[0].tenant_id, + filtered_kb_ids, + 1, + self._param.top_n, + self._param.similarity_threshold, + 1 - self._param.keywords_similarity_weight, + aggs=False, + rerank_mdl=rerank_mdl, + rank_feature=label_question(query, kbs), + ) else: kbinfos = {"chunks": [], "doc_aggs": []} if self._param.use_kg and kbs: - ck = settings.kg_retrievaler.retrieval(query, - [kbs[0].tenant_id], - filtered_kb_ids, - embd_mdl, - LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) + ck = settings.kg_retrievaler.retrieval(query, [kbs[0].tenant_id], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) @@ -123,5 +130,3 @@ class Retrieval(ComponentBase, ABC): df = pd.DataFrame({"content": kb_prompt(kbinfos, 200000), "chunks": json.dumps(kbinfos["chunks"])}) logging.debug("{} {}".format(query, df)) return df.dropna() - - diff --git a/web/src/components/knowledge-base-item.tsx b/web/src/components/knowledge-base-item.tsx index 65b218c76..7ae98e4d5 100644 --- a/web/src/components/knowledge-base-item.tsx +++ b/web/src/components/knowledge-base-item.tsx @@ -10,13 +10,17 @@ import { FormControl, FormField, FormItem, FormLabel } from './ui/form'; import { MultiSelect } from './ui/multi-select'; interface KnowledgeBaseItemProps { + label?: string; tooltipText?: string; + name?: string; required?: boolean; onChange?(): void; } const KnowledgeBaseItem = ({ + label, tooltipText, + name, required = true, onChange, }: KnowledgeBaseItemProps) => { @@ -40,8 +44,8 @@ const KnowledgeBaseItem = ({ return ( ), + [BeginQueryType.KnowledgeBases]: ( + + ) }; return ( @@ -182,12 +190,16 @@ const DebugContent = ({ if (Array.isArray(value)) { nextValue = ``; - value.forEach((x) => { - nextValue += - x?.originFileObj instanceof File - ? `${x.name}\n${x.response?.data}\n----\n` - : `${x.url}\n${x.result}\n----\n`; - }); + if (item.type === 'kb') { + nextValue = value.join(',') + } else { + value.forEach((x) => { + nextValue += + x?.originFileObj instanceof File + ? `${x.name}\n${x.response?.data}\n----\n` + : `${x.url}\n${x.result}\n----\n`; + }); + } } return { ...item, value: nextValue }; });