Fix multiple generate (#1722)

### What problem does this PR solve?

#1625

Duplicated citation handling in the `Generate` component is factored out into a shared `set_cite()` helper that both `_run()` and `stream_output()` call, and the `Answer` component now accumulates streamed chunks so each yield carries the full answer text received so far.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: H · 2024-07-29 09:27:59 +08:00 (committed by GitHub)
Commit: 013856b604 · Parent: 61096596bc
2 changed files with 35 additions and 52 deletions


```diff
@@ -59,8 +59,10 @@ class Answer(ComponentBase, ABC):
         stream = self.get_stream_input()
         if isinstance(stream, pd.DataFrame):
             res = stream
+            answer = ""
             for ii, row in stream.iterrows():
-                yield row.to_dict()
+                answer += row.to_dict()["content"]
+                yield {"content": answer}
         else:
             for st in stream():
                 res = st
```

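For context, a minimal self-contained sketch (toy data, not part of the diff) of the new accumulation behaviour: the `Answer` component now concatenates the `content` of each streamed row and yields the full partial answer each time, rather than forwarding raw per-row dicts.

```python
import pandas as pd

# Toy stand-in for the stream DataFrame the Answer component receives.
stream = pd.DataFrame([{"content": "Paris is"}, {"content": " the capital"}, {"content": " of France."}])

answer = ""
for _, row in stream.iterrows():
    answer += row.to_dict()["content"]   # accumulate everything seen so far
    print({"content": answer})           # each yield now carries the full partial answer
# {'content': 'Paris is'}
# {'content': 'Paris is the capital'}
# {'content': 'Paris is the capital of France.'}
```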

```diff
@@ -67,6 +67,34 @@ class Generate(ComponentBase):
         cpnts = [para["component_id"] for para in self._param.parameters]
         return cpnts
 
+    def set_cite(self, retrieval_res, answer):
+        answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
+                                                   [ck["vector"] for _, ck in retrieval_res.iterrows()],
+                                                   LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
+                                                             self._canvas.get_embedding_model()), tkweight=0.7,
+                                                   vtweight=0.3)
+        doc_ids = set([])
+        recall_docs = []
+        for i in idx:
+            did = retrieval_res.loc[int(i), "doc_id"]
+            if did in doc_ids: continue
+            doc_ids.add(did)
+            recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
+        del retrieval_res["vector"]
+        del retrieval_res["content_ltks"]
+        reference = {
+            "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
+            "doc_aggs": recall_docs
+        }
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        res = {"content": answer, "reference": reference}
+        return res
+
     def _run(self, history, **kwargs):
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
         prompt = self._param.prompt
```
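
A simplified, self-contained sketch of the reference assembly that `set_cite()` performs, with the `retrievaler.insert_citations()` call left out; the toy DataFrame and the `idx` list stand in for real retrieval results and the citation indices it would return.

```python
import pandas as pd

retrieval_res = pd.DataFrame([
    {"doc_id": "d1", "docnm_kwd": "manual.pdf", "content_ltks": "tok a", "vector": [0.1]},
    {"doc_id": "d1", "docnm_kwd": "manual.pdf", "content_ltks": "tok b", "vector": [0.2]},
    {"doc_id": "d2", "docnm_kwd": "faq.md", "content_ltks": "tok c", "vector": [0.3]},
])
idx = [0, 1, 2]                       # chunk indices that insert_citations() would return

doc_ids, recall_docs = set(), []
for i in idx:
    did = retrieval_res.loc[int(i), "doc_id"]
    if did in doc_ids:
        continue                      # aggregate each document only once
    doc_ids.add(did)
    recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})

del retrieval_res["vector"]           # internal columns are dropped before serialising
del retrieval_res["content_ltks"]
reference = {"chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
             "doc_aggs": recall_docs}
print(reference["doc_aggs"])          # [{'doc_id': 'd1', ...}, {'doc_id': 'd2', ...}]
```
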
```diff
@@ -87,9 +115,8 @@ class Generate(ComponentBase):
             prompt = re.sub(r"\{%s\}" % n, str(v), prompt)
 
         downstreams = self._canvas.get_component(self._id)["downstream"]
-        if kwargs.get("stream") \
-                and len(downstreams) == 1 \
-                and self._canvas.get_component(downstreams[0])["obj"].component_name.lower() == "answer":
+        if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
+            "obj"].component_name.lower() == "answer":
             return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
 
         if "empty_response" in retrieval_res.columns:
```
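
The streaming dispatch relies on `functools.partial`: when the only downstream component is an `Answer`, `_run()` returns a deferred generator instead of a finished DataFrame, and the downstream component iterates it later (the `for st in stream():` branch in the `Answer` hunk above). A hedged, toy illustration of that pattern, with stand-in names:

```python
from functools import partial

def stream_output(model_name, prompt):
    # Stand-in for Generate.stream_output: yields partial results as they arrive.
    for piece in ["Hel", "Hello", "Hello, world"]:
        yield {"content": piece}

# _run() returns the callable without invoking it...
deferred = partial(stream_output, "some-chat-model", "Say hello")

# ...and the consumer decides when to iterate it.
for chunk in deferred():
    print(chunk)
```
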
```diff
@@ -97,27 +124,8 @@ class Generate(ComponentBase):
         ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
                             self._param.gen_conf())
 
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
-            ans, idx = retrievaler.insert_citations(ans,
-                                                    [ck["content_ltks"]
-                                                     for _, ck in retrieval_res.iterrows()],
-                                                    [ck["vector"]
-                                                     for _, ck in retrieval_res.iterrows()],
-                                                    LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                              self._canvas.get_embedding_model()),
-                                                    tkweight=0.7,
-                                                    vtweight=0.3)
-            del retrieval_res["vector"]
-            retrieval_res = retrieval_res.to_dict("records")
-            df = []
-            for i in idx:
-                df.append(retrieval_res[int(i)])
-                r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans)
-                assert r, f"{i} => {ans}"
-                df[-1]["content"] = r.group(1)
-                ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans)
-            if ans: df.append({"content": ans})
+            df = self.set_cite(retrieval_res, ans)
             return pd.DataFrame(df)
 
         return Generate.be_output(ans)
```
```diff
@@ -138,34 +146,7 @@ class Generate(ComponentBase):
             yield res
 
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
-            answer, idx = retrievaler.insert_citations(answer,
-                                                       [ck["content_ltks"]
-                                                        for _, ck in retrieval_res.iterrows()],
-                                                       [ck["vector"]
-                                                        for _, ck in retrieval_res.iterrows()],
-                                                       LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                                 self._canvas.get_embedding_model()),
-                                                       tkweight=0.7,
-                                                       vtweight=0.3)
-            doc_ids = set([])
-            recall_docs = []
-            for i in idx:
-                did = retrieval_res.loc[int(i), "doc_id"]
-                if did in doc_ids: continue
-                doc_ids.add(did)
-                recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
-            del retrieval_res["vector"]
-            del retrieval_res["content_ltks"]
-            reference = {
-                "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
-                "doc_aggs": recall_docs
-            }
-            if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
-                answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-            res = {"content": answer, "reference": reference}
+            res = self.set_cite(retrieval_res, answer)
             yield res
 
         self.set_output(res)
```
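
Taken together, `stream_output()` now yields growing partial answers and, when citation is enabled, a final payload whose `reference` comes from `set_cite()`. A hypothetical consumer (names are illustrative, not the project's API) might keep the last payload to get both the full answer and its references:

```python
def toy_stream():
    # Stand-in for the generator returned by partial(self.stream_output, ...).
    yield {"content": "Paris"}
    yield {"content": "Paris is the capital of France"}
    yield {"content": "Paris is the capital of France ##0$$",
           "reference": {"chunks": [], "doc_aggs": [{"doc_id": "d1", "doc_name": "geo.pdf"}]}}

last = None
for res in toy_stream():
    last = res                         # the final payload is the most complete one
    print(res["content"])              # progressively longer answer text

print(last.get("reference", {}))       # citations are attached only to the final payload
```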