Fix: cite dysfunction for Generate component. (#7117)

### What problem does this PR solve?

Addresses issue #7097.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu 2025-04-18 18:05:26 +08:00 committed by GitHub
parent 8b8a2f2949
commit 487aed419e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 11 deletions

View File

@ -384,6 +384,11 @@ class ComponentBase(ABC):
"params": {} "params": {}
} }
""" """
out = getattr(self._param, self._param.output_var_name)
if isinstance(out, pd.DataFrame) and "chunks" in out:
del out["chunks"]
setattr(self._param, self._param.output_var_name, out)
return """{{ return """{{
"component_name": "{}", "component_name": "{}",
"params": {}, "params": {},

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json
import re import re
from functools import partial from functools import partial
import pandas as pd import pandas as pd
@ -74,29 +75,30 @@ class Generate(ComponentBase):
return list(cpnts) return list(cpnts)
def set_cite(self, retrieval_res, answer): def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
if "empty_response" in retrieval_res.columns: if "empty_response" in retrieval_res.columns:
retrieval_res["empty_response"].fillna("", inplace=True) retrieval_res["empty_response"].fillna("", inplace=True)
chunks = json.loads(retrieval_res["chunks"][0])
answer, idx = settings.retrievaler.insert_citations(answer, answer, idx = settings.retrievaler.insert_citations(answer,
[ck["content_ltks"] for _, ck in retrieval_res.iterrows()], [ck["content_ltks"] for ck in chunks],
[ck["vector"] for _, ck in retrieval_res.iterrows()], [ck["vector"] for ck in chunks],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7, self._canvas.get_embedding_model()), tkweight=0.7,
vtweight=0.3) vtweight=0.3)
doc_ids = set([]) doc_ids = set([])
recall_docs = [] recall_docs = []
for i in idx: for i in idx:
did = retrieval_res.loc[int(i), "doc_id"] did = chunks[int(i)]["doc_id"]
if did in doc_ids: if did in doc_ids:
continue continue
doc_ids.add(did) doc_ids.add(did)
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]}) recall_docs.append({"doc_id": did, "doc_name": chunks[int(i)]["docnm_kwd"]})
del retrieval_res["vector"] for c in chunks:
del retrieval_res["content_ltks"] del c["vector"]
del c["content_ltks"]
reference = { reference = {
"chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()], "chunks": chunks,
"doc_aggs": recall_docs "doc_aggs": recall_docs
} }
@ -200,7 +202,7 @@ class Generate(ComponentBase):
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf()) ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL) ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: if self._param.cite and "chunks" in retrieval_res.columns:
res = self.set_cite(retrieval_res, ans) res = self.set_cite(retrieval_res, ans)
return pd.DataFrame([res]) return pd.DataFrame([res])
@ -229,7 +231,7 @@ class Generate(ComponentBase):
answer = ans answer = ans
yield res yield res
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: if self._param.cite and "chunks" in retrieval_res.columns:
res = self.set_cite(retrieval_res, answer) res = self.set_cite(retrieval_res, answer)
yield res yield res

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json
import logging import logging
from abc import ABC from abc import ABC
@ -103,7 +104,7 @@ class Retrieval(ComponentBase, ABC):
df["empty_response"] = self._param.empty_response df["empty_response"] = self._param.empty_response
return df return df
df = pd.DataFrame({"content": kb_prompt(kbinfos, 200000)}) df = pd.DataFrame({"content": kb_prompt(kbinfos, 200000), "chunks": json.dumps(kbinfos["chunks"])})
logging.debug("{} {}".format(query, df)) logging.debug("{} {}".format(query, df))
return df.dropna() return df.dropna()