From 487aed419e393122558770b9df0c65e18412cc88 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 18 Apr 2025 18:05:26 +0800 Subject: [PATCH] Fix: cite dysfunction for G component. (#7117) ### What problem does this PR solve? #7097 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/component/base.py | 5 +++++ agent/component/generate.py | 22 ++++++++++++---------- agent/component/retrieval.py | 3 ++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/agent/component/base.py b/agent/component/base.py index eee28b442..a09e6c0a2 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -384,6 +384,11 @@ class ComponentBase(ABC): "params": {} } """ + out = getattr(self._param, self._param.output_var_name) + if isinstance(out, pd.DataFrame) and "chunks" in out: + del out["chunks"] + setattr(self._param, self._param.output_var_name, out) + return """{{ "component_name": "{}", "params": {}, diff --git a/agent/component/generate.py b/agent/component/generate.py index 82963ea8e..fc17f6958 100644 --- a/agent/component/generate.py +++ b/agent/component/generate.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
# +import json import re from functools import partial import pandas as pd @@ -74,29 +75,30 @@ class Generate(ComponentBase): return list(cpnts) def set_cite(self, retrieval_res, answer): - retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True) if "empty_response" in retrieval_res.columns: retrieval_res["empty_response"].fillna("", inplace=True) + chunks = json.loads(retrieval_res["chunks"][0]) answer, idx = settings.retrievaler.insert_citations(answer, - [ck["content_ltks"] for _, ck in retrieval_res.iterrows()], - [ck["vector"] for _, ck in retrieval_res.iterrows()], + [ck["content_ltks"] for ck in chunks], + [ck["vector"] for ck in chunks], LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, self._canvas.get_embedding_model()), tkweight=0.7, vtweight=0.3) doc_ids = set([]) recall_docs = [] for i in idx: - did = retrieval_res.loc[int(i), "doc_id"] + did = chunks[int(i)]["doc_id"] if did in doc_ids: continue doc_ids.add(did) - recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]}) + recall_docs.append({"doc_id": did, "doc_name": chunks[int(i)]["docnm_kwd"]}) - del retrieval_res["vector"] - del retrieval_res["content_ltks"] + for c in chunks: + del c["vector"] + del c["content_ltks"] reference = { - "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()], + "chunks": chunks, "doc_aggs": recall_docs } @@ -200,7 +202,7 @@ class Generate(ComponentBase): ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf()) ans = re.sub(r".*", "", ans, flags=re.DOTALL) - if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: + if self._param.cite and "chunks" in retrieval_res.columns: res = self.set_cite(retrieval_res, ans) return pd.DataFrame([res]) @@ -229,7 +231,7 @@ class Generate(ComponentBase): answer = ans yield res - if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: + if 
self._param.cite and "chunks" in retrieval_res.columns: res = self.set_cite(retrieval_res, answer) yield res diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index 90cb8c9eb..8a2568c13 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json import logging from abc import ABC @@ -103,7 +104,7 @@ class Retrieval(ComponentBase, ABC): df["empty_response"] = self._param.empty_response return df - df = pd.DataFrame({"content": kb_prompt(kbinfos, 200000)}) + df = pd.DataFrame({"content": kb_prompt(kbinfos, 200000), "chunks": json.dumps(kbinfos["chunks"])}) logging.debug("{} {}".format(query, df)) return df.dropna()