diff --git a/graph/component/answer.py b/graph/component/answer.py index 112817852..3973111d7 100644 --- a/graph/component/answer.py +++ b/graph/component/answer.py @@ -59,8 +59,10 @@ class Answer(ComponentBase, ABC): stream = self.get_stream_input() if isinstance(stream, pd.DataFrame): res = stream + answer = "" for ii, row in stream.iterrows(): - yield row.to_dict() + answer += row.to_dict()["content"] + yield {"content": answer} else: for st in stream(): res = st diff --git a/graph/component/generate.py b/graph/component/generate.py index d6d950e42..924af903f 100644 --- a/graph/component/generate.py +++ b/graph/component/generate.py @@ -67,6 +67,34 @@ class Generate(ComponentBase): cpnts = [para["component_id"] for para in self._param.parameters] return cpnts + def set_cite(self, retrieval_res, answer): + answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()], + [ck["vector"] for _, ck in retrieval_res.iterrows()], + LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, + self._canvas.get_embedding_model()), tkweight=0.7, + vtweight=0.3) + doc_ids = set([]) + recall_docs = [] + for i in idx: + did = retrieval_res.loc[int(i), "doc_id"] + if did in doc_ids: continue + doc_ids.add(did) + recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]}) + + del retrieval_res["vector"] + del retrieval_res["content_ltks"] + + reference = { + "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()], + "doc_aggs": recall_docs + } + + if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: + answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" + res = {"content": answer, "reference": reference} + + return res + def _run(self, history, **kwargs): chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) prompt = self._param.prompt @@ -87,9 +115,8 @@ class Generate(ComponentBase): prompt = re.sub(r"\{%s\}" % n, str(v), prompt) downstreams = self._canvas.get_component(self._id)["downstream"] - if kwargs.get("stream") \ - and len(downstreams) == 1 \ - and self._canvas.get_component(downstreams[0])["obj"].component_name.lower() == "answer": + if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[ + "obj"].component_name.lower() == "answer": return partial(self.stream_output, chat_mdl, prompt, retrieval_res) if "empty_response" in retrieval_res.columns: @@ -97,27 +124,8 @@ class Generate(ComponentBase): ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf()) - if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: - ans, idx = retrievaler.insert_citations(ans, - [ck["content_ltks"] - for _, ck in retrieval_res.iterrows()], - [ck["vector"] - for _, ck in retrieval_res.iterrows()], - LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, - self._canvas.get_embedding_model()), - tkweight=0.7, - vtweight=0.3) - del retrieval_res["vector"] - retrieval_res = retrieval_res.to_dict("records") - df = [] - for i in idx: - df.append(retrieval_res[int(i)]) - r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans) - assert r, f"{i} => {ans}" - df[-1]["content"] = r.group(1) - ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans) - if ans: df.append({"content": ans}) + df = self.set_cite(retrieval_res, ans) return pd.DataFrame(df) return Generate.be_output(ans) @@ -138,34 +146,7 @@ class Generate(ComponentBase): yield res if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: - answer, idx = retrievaler.insert_citations(answer, - [ck["content_ltks"] - for _, ck in retrieval_res.iterrows()], - [ck["vector"] - for _, ck in retrieval_res.iterrows()], - LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, - self._canvas.get_embedding_model()), - tkweight=0.7, - vtweight=0.3) - doc_ids = set([]) - recall_docs = [] - for i in idx: - did = retrieval_res.loc[int(i), "doc_id"] - if did in doc_ids: continue - doc_ids.add(did) - recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]}) - - del retrieval_res["vector"] - del retrieval_res["content_ltks"] - - reference = { - "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()], - "doc_aggs": recall_docs - } - - if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: - answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" - res = {"content": answer, "reference": reference} + res = self.set_cite(retrieval_res, answer) yield res self.set_output(res)