diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index b6013506a..ec45ba1ae 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -50,12 +50,15 @@ class Retrieval(ComponentBase, ABC): component_name = "Retrieval" def _run(self, history, **kwargs): - query = [] - for role, cnt in history[::-1][:self._param.message_history_window_size]: - if role != "user":continue - query.append(cnt) - # query = "\n".join(query) - query = query[0] + # query = [] + # for role, cnt in history[::-1][:self._param.message_history_window_size]: + # if role != "user":continue + # query.append(cnt) + # # query = "\n".join(query) + # query = query[0] + query = self.get_input() + query = str(query["content"][0]) if "content" in query else "" + kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids) if not kbs: raise ValueError("Can't find knowledgebases by {}".format(self._param.kb_ids)) diff --git a/agent/component/rewrite.py b/agent/component/rewrite.py index 31dc0880e..63e473d8a 100644 --- a/agent/component/rewrite.py +++ b/agent/component/rewrite.py @@ -91,7 +91,11 @@ class RewriteQuestion(Generate, ABC): raise Exception("Sorry! Nothing relevant found.") self._loop += 1 - conv = self._canvas.get_history(4) + hist = self._canvas.get_history(4) + conv = [] + for m in hist: + if m["role"] not in ["user", "assistant"]: continue + conv.append("{}: {}".format(m["role"].upper(), m["content"])) conv = "\n".join(conv) chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)