diff --git a/agent/component/base.py b/agent/component/base.py
index bca2fc0be..d105d43b8 100644
--- a/agent/component/base.py
+++ b/agent/component/base.py
@@ -37,6 +37,7 @@ class ComponentParamBase(ABC):
         self.output_var_name = "output"
         self.message_history_window_size = 22
         self.query = []
+        self.inputs = []
 
     def set_name(self, name: str):
         self._name = name
@@ -444,8 +445,13 @@ class ComponentBase(ABC):
         if self._param.query:
             outs = []
             for q in self._param.query:
-                if q["value"]: outs.append(pd.DataFrame([{"content": q["value"]}]))
-                if q["component_id"]: outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
+                if q["component_id"]:
+                    outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
+                    self._param.inputs.append({"component_id": q["component_id"],
+                                               "content": "\n".join([str(d["content"]) for d in outs[-1].to_dict('records')])})
+                elif q["value"]:
+                    self._param.inputs.append({"component_id": None, "content": q["value"]})
+                    outs.append(pd.DataFrame([{"content": q["value"]}]))
             if outs:
                 df = pd.concat(outs, ignore_index=True)
                 if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
@@ -463,31 +469,38 @@
             if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
                 o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                 if o is not None:
+                    o["component_id"] = u
                     upstream_outs.append(o)
                     continue
-            if u not in self._canvas.get_component(self._id)["upstream"]: continue
+            if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
             if self.component_name.lower().find("switch") < 0 \
                     and self.get_component_name(u) in ["relevant", "categorize"]:
                 continue
             if u.lower().find("answer") >= 0:
                 for r, c in self._canvas.history[::-1]:
                     if r == "user":
-                        upstream_outs.append(pd.DataFrame([{"content": c}]))
+                        upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
                         break
                 break
             if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
                 continue
             o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
             if o is not None:
+                o["component_id"] = u
                 upstream_outs.append(o)
             break
 
-        if upstream_outs:
-            df = pd.concat(upstream_outs, ignore_index=True)
-            if "content" in df:
-                df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
-            return df
-        return pd.DataFrame(self._canvas.get_history(3)[-1:])
+        assert upstream_outs, "Can't inference the where the component input is."
+
+        df = pd.concat(upstream_outs, ignore_index=True)
+        if "content" in df:
+            df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
+
+        self._param.inputs = []
+        for _,r in df.iterrows():
+            self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
+
+        return df
 
     def get_stream_input(self):
         reversed_cpnts = []
diff --git a/agent/component/cite.py b/agent/component/cite.py
deleted file mode 100644
index f50bc4e81..000000000
--- a/agent/component/cite.py
+++ /dev/null
@@ -1,75 +0,0 @@
-#
-#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-from abc import ABC
-
-import pandas as pd
-
-from api.db import LLMType
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
-from agent.component.base import ComponentBase, ComponentParamBase
-
-
-class CiteParam(ComponentParamBase):
-
-    """
-    Define the Retrieval component parameters.
-    """
-    def __init__(self):
-        super().__init__()
-        self.cite_sources = []
-
-    def check(self):
-        self.check_empty(self.cite_source, "Please specify where you want to cite from.")
-
-
-class Cite(ComponentBase, ABC):
-    component_name = "Cite"
-
-    def _run(self, history, **kwargs):
-        input = "\n- ".join(self.get_input()["content"])
-        sources = [self._canvas.get_component(cpn_id).output()[1] for cpn_id in self._param.cite_source]
-        query = []
-        for role, cnt in history[::-1][:self._param.message_history_window_size]:
-            if role != "user":continue
-            query.append(cnt)
-        query = "\n".join(query)
-
-        kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
-        if not kbs:
-            raise ValueError("Can't find knowledgebases by {}".format(self._param.kb_ids))
-        embd_nms = list(set([kb.embd_id for kb in kbs]))
-        assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
-
-        embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, embd_nms[0])
-
-        rerank_mdl = None
-        if self._param.rerank_id:
-            rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
-
-        kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
-                                        1, self._param.top_n,
-                                        self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
-                                        aggs=False, rerank_mdl=rerank_mdl)
-
-        if not kbinfos["chunks"]: return pd.DataFrame()
-        df = pd.DataFrame(kbinfos["chunks"])
-        df["content"] = df["content_with_weight"]
-        del df["content_with_weight"]
-        return df
-
-
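Note on the base.py hunks above: get_input() now records, alongside each row it returns, the id of the upstream component that produced it, and mirrors that into self._param.inputs. A minimal sketch of the resulting shape, using a toy DataFrame in place of real upstream outputs (the component ids and contents below are invented):

    import pandas as pd

    # Stand-in for the DataFrame that get_input() assembles from upstream outputs;
    # after this change every row also carries the id of its source component.
    df = pd.DataFrame([
        {"component_id": "answer:0", "content": "What is RAGFlow?"},
        {"component_id": "retrieval:0", "content": "RAGFlow is an open-source RAG engine."},
    ])

    # Mirrors the bookkeeping added at the end of ComponentBase.get_input():
    # _param.inputs becomes a list of {"component_id", "content"} dicts.
    inputs = [{"component_id": r["component_id"], "content": r["content"]}
              for _, r in df.iterrows()]
    print(inputs)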
diff --git a/agent/component/generate.py b/agent/component/generate.py
index 19fd9159e..ab5b07ed7 100644
--- a/agent/component/generate.py
+++ b/agent/component/generate.py
@@ -101,8 +101,8 @@ class Generate(ComponentBase):
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
         prompt = self._param.prompt
 
-        retrieval_res = self.get_input()
-        input = (" - "+"\n - ".join([c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
+        retrieval_res = []
+        self._param.inputs = []
         for para in self._param.parameters:
             cpn = self._canvas.get_component(para["component_id"])["obj"]
             if cpn.component_name.lower() == "answer":
@@ -112,12 +112,24 @@
             if "content" not in out.columns:
                 kwargs[para["key"]] = "Nothing"
             else:
+                if cpn.component_name.lower() == "retrieval":
+                    retrieval_res.append(out)
                 kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
+            self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
+
+        if retrieval_res:
+            retrieval_res = pd.concat(retrieval_res, ignore_index=True)
+        else: retrieval_res = pd.DataFrame([])
 
-        kwargs["input"] = input
         for n, v in kwargs.items():
             prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
 
+        if not self._param.inputs and prompt.find("{input}") >= 0:
+            retrieval_res = self.get_input()
+            input = (" - " + "\n - ".join(
+                [c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
+            prompt = re.sub(r"\{input\}", re.escape(input), prompt)
+
         downstreams = self._canvas.get_component(self._id)["downstream"]
         if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
                 "obj"].component_name.lower() == "answer":
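Note on the generate.py hunks above: explicitly configured parameters are substituted into the prompt first, and the {input} placeholder is only filled from get_input() when no parameter inputs were recorded. A simplified, self-contained sketch of that ordering (placeholder names and contents are invented, and the re.escape() the real code applies to replacement values is omitted):

    import re

    prompt = "Context:\n{docs}\n\nQuestion: {question}"
    kwargs = {"docs": " - chunk one\n - chunk two",  # from self._param.parameters
              "question": "What is RAGFlow?"}
    inputs = [{"component_id": "retrieval:0", "content": kwargs["docs"]}]

    # 1) substitute every configured parameter
    for n, v in kwargs.items():
        prompt = re.sub(r"\{%s\}" % re.escape(n), v, prompt)

    # 2) fall back to {input} only when nothing was configured (skipped here)
    if not inputs and prompt.find("{input}") >= 0:
        prompt = re.sub(r"\{input\}", "content pulled via get_input()", prompt)

    print(prompt)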
kwargs["input"] = input for n, v in kwargs.items(): prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt) + if not self._param.inputs and prompt.find("{input}") >= 0: + retrieval_res = self.get_input() + input = (" - " + "\n - ".join( + [c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else "" + prompt = re.sub(r"\{input\}", re.escape(input), prompt) + downstreams = self._canvas.get_component(self._id)["downstream"] if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[ "obj"].component_name.lower() == "answer": diff --git a/agent/component/keyword.py b/agent/component/keyword.py index c5083efae..805dddf3c 100644 --- a/agent/component/keyword.py +++ b/agent/component/keyword.py @@ -50,14 +50,11 @@ class KeywordExtract(Generate, ABC): component_name = "KeywordExtract" def _run(self, history, **kwargs): - q = "" - for r, c in self._canvas.history[::-1]: - if r == "user": - q += c - break + query = self.get_input() + query = str(query["content"][0]) if "content" in query else "" chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) - ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}], + ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": query}], self._param.gen_conf()) ans = re.sub(r".*keyword:", "", ans).strip() diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index dbbfe13c3..96bd8c4e5 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -396,6 +396,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): rows = ["|" + "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]] + rows = [r for r in rows if re.sub(r"[ |]+", "", r)] if quota: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) else: diff --git a/conf/llm_factories.json b/conf/llm_factories.json index d67e3bb8f..ca8fa3085 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -1299,7 +1299,7 @@ "llm": [] }, { - "name": "cohere", + "name": "Cohere", "logo": "", "tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK", "status": "1", diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index b4a56f2ec..1a28a70e1 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -39,7 +39,7 @@ EmbeddingModel = { "NVIDIA": NvidiaEmbed, "LM-Studio": LmStudioEmbed, "OpenAI-API-Compatible": OpenAI_APIEmbed, - "cohere": CoHereEmbed, + "Cohere": CoHereEmbed, "TogetherAI": TogetherAIEmbed, "PerfXCloud": PerfXCloudEmbed, "Upstage": UpstageEmbed, @@ -92,7 +92,7 @@ ChatModel = { "NVIDIA": NvidiaChat, "LM-Studio": LmStudioChat, "OpenAI-API-Compatible": OpenAI_APIChat, - "cohere": CoHereChat, + "Cohere": CoHereChat, "LeptonAI": LeptonAIChat, "TogetherAI": TogetherAIChat, "PerfXCloud": PerfXCloudChat, @@ -117,7 +117,7 @@ RerankModel = { "NVIDIA": NvidiaRerank, "LM-Studio": LmStudioRerank, "OpenAI-API-Compatible": OpenAI_APIRerank, - "cohere": CoHereRerank, + "Cohere": CoHereRerank, "TogetherAI": TogetherAIRerank, "SILICONFLOW": SILICONFLOWRerank, "BaiduYiyan": BaiduYiyanRerank, diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index fc8f6fa03..c322f257d 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -394,6 +394,7 @@ class VoyageRerank(Base): rank[r.index] = r.relevance_score return rank, res.total_tokens + class QWenRerank(Base): def __init__(self, key, model_name='gte-rerank', base_url=None, 
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index d67e3bb8f..ca8fa3085 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -1299,7 +1299,7 @@
             "llm": []
         },
         {
-            "name": "cohere",
+            "name": "Cohere",
             "logo": "",
             "tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
             "status": "1",
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index b4a56f2ec..1a28a70e1 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -39,7 +39,7 @@ EmbeddingModel = {
     "NVIDIA": NvidiaEmbed,
     "LM-Studio": LmStudioEmbed,
     "OpenAI-API-Compatible": OpenAI_APIEmbed,
-    "cohere": CoHereEmbed,
+    "Cohere": CoHereEmbed,
     "TogetherAI": TogetherAIEmbed,
     "PerfXCloud": PerfXCloudEmbed,
     "Upstage": UpstageEmbed,
@@ -92,7 +92,7 @@ ChatModel = {
     "NVIDIA": NvidiaChat,
     "LM-Studio": LmStudioChat,
     "OpenAI-API-Compatible": OpenAI_APIChat,
-    "cohere": CoHereChat,
+    "Cohere": CoHereChat,
     "LeptonAI": LeptonAIChat,
     "TogetherAI": TogetherAIChat,
     "PerfXCloud": PerfXCloudChat,
@@ -117,7 +117,7 @@ RerankModel = {
     "NVIDIA": NvidiaRerank,
     "LM-Studio": LmStudioRerank,
     "OpenAI-API-Compatible": OpenAI_APIRerank,
-    "cohere": CoHereRerank,
+    "Cohere": CoHereRerank,
     "TogetherAI": TogetherAIRerank,
     "SILICONFLOW": SILICONFLOWRerank,
     "BaiduYiyan": BaiduYiyanRerank,
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index fc8f6fa03..c322f257d 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -394,6 +394,7 @@ class VoyageRerank(Base):
             rank[r.index] = r.relevance_score
         return rank, res.total_tokens
 
+
 class QWenRerank(Base):
     def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
         import dashscope
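Note on the Cohere renaming above: the factory "name" in conf/llm_factories.json and the keys of EmbeddingModel, ChatModel and RerankModel appear to be used as the same lookup key, so the capitalization has to change in both places together. A minimal sketch of that lookup, with a dummy class standing in for the real CoHereEmbed and an invented API key and model name:

    class CoHereEmbed:  # placeholder for rag.llm.CoHereEmbed
        def __init__(self, key, model_name):
            self.model_name = model_name

    EmbeddingModel = {"Cohere": CoHereEmbed}  # key must match the JSON "name" field

    factory = "Cohere"                  # as read from llm_factories.json
    embd_cls = EmbeddingModel[factory]  # a KeyError here means the casing drifted
    print(embd_cls("fake-api-key", "embed-english-v3.0").model_name)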