refine agent (#2787)

### What problem does this PR solve?



### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [x] Performance Improvement
- [ ] Other (please describe):
Kevin Hu · 2024-10-10 17:07:36 +08:00 · committed by GitHub
commit 7742f67481 · parent 6af9d4e5f9
4 changed files with 42 additions and 10 deletions


@@ -73,7 +73,7 @@ class Categorize(Generate, ABC):
     def _run(self, history, **kwargs):
         input = self.get_input()
-        input = "Question: " + ("; ".join(input["content"]) if "content" in input else "") + "Category: "
+        input = "Question: " + (list(input["content"])[-1] if "content" in input else "") + "\tCategory: "
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
         ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": input}],
                             self._param.gen_conf())
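
A quick sketch of what this change does to the categorization prompt (the sample payload below is invented; in the component, the dict comes from `self.get_input()`):

```python
# Hypothetical upstream output with two accumulated entries; in the component
# this dict is what self.get_input() returns.
upstream = {"content": ["What is RAGFlow?", "How do I deploy it with Docker?"]}

# Before: every entry was joined, so earlier turns leaked into the category
# prompt, and there was no separator before "Category:".
old_input = "Question: " + ("; ".join(upstream["content"]) if "content" in upstream else "") + "Category: "
# 'Question: What is RAGFlow?; How do I deploy it with Docker?Category: '

# After: only the most recent entry is categorized, with a tab before the label.
new_input = "Question: " + (list(upstream["content"])[-1] if "content" in upstream else "") + "\tCategory: "
# 'Question: How do I deploy it with Docker?\tCategory: '

print(old_input)
print(new_input)
```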


@@ -101,7 +101,7 @@ class Generate(ComponentBase):
         prompt = self._param.prompt
         retrieval_res = self.get_input()
-        input = (" - " + "\n - ".join(retrieval_res["content"])) if "content" in retrieval_res else ""
+        input = (" - " + "\n - ".join([c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
         for para in self._param.parameters:
             cpn = self._canvas.get_component(para["component_id"])["obj"]
             _, out = cpn.output(allow_partial=False)
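
Why the `isinstance` filter helps, sketched with made-up data (the assumption here is that upstream components can emit non-string rows alongside text chunks): `str.join` raises `TypeError` on a mixed list, while the comprehension keeps only the text.

```python
# Hypothetical retrieval output: text chunks plus one structured (non-string) row.
retrieval_res = {"content": ["chunk one", {"chunk_id": 42}, "chunk three"]}

# Before: join() fails as soon as it hits the dict.
try:
    " - " + "\n - ".join(retrieval_res["content"])
except TypeError as e:
    print("old behaviour:", e)  # sequence item 1: expected str instance, dict found

# After: non-string entries are skipped and the prompt context is still assembled.
text = (" - " + "\n - ".join([c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
print(text)
#  - chunk one
#  - chunk three
```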


@@ -33,7 +33,7 @@ class RewriteQuestionParam(GenerateParam):
     def check(self):
         super().check()

-    def get_prompt(self):
+    def get_prompt(self, conv):
         self.prompt = """
 You are an expert at query expansion to generate a paraphrasing of a question.
 I can't retrieval relevant information from the knowledge base by using user's question directly.
@@ -43,6 +43,40 @@ class RewriteQuestionParam(GenerateParam):
 And return 5 versions of question and one is from translation.
 Just list the question. No other words are needed.
         """
+        return f"""
+Role: A helpful assistant
+Task: Generate a full user question that would follow the conversation.
+Requirements & Restrictions:
+  - Text generated MUST be in the same language of the original user's question.
+  - If the user's latest question is completely, don't do anything, just return the original question.
+  - DON'T generate anything except a refined question.
+######################
+-Examples-
+######################
+# Example 1
+## Conversation
+USER: What is the name of Donald Trump's father?
+ASSISTANT: Fred Trump.
+USER: And his mother?
+###############
+Output: What's the name of Donald Trump's mother?
+------------
+# Example 2
+## Conversation
+USER: What is the name of Donald Trump's father?
+ASSISTANT: Fred Trump.
+USER: And his mother?
+ASSISTANT: Mary Trump.
+User: What's her full name?
+###############
+Output: What's the full name of Donald Trump's mother Mary Trump?
+######################
+# Real Data
+## Conversation
+{conv}
+###############
+        """
         return self.prompt
@@ -56,14 +90,12 @@ class RewriteQuestion(Generate, ABC):
             self._loop = 0
             raise Exception("Sorry! Nothing relevant found.")
         self._loop += 1
-        q = "Question: "
-        for r, c in self._canvas.history[::-1]:
-            if r == "user":
-                q += c
-                break
+        conv = self._canvas.get_history(4)
+        conv = "\n".join(conv)
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
-        ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}],
+        ans = chat_mdl.chat(self._param.get_prompt(conv), [{"role": "user", "content": "Output: "}],
                             self._param.gen_conf())
         self._canvas.history.pop()
         self._canvas.history.append(("user", ans))
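
To make the new data flow concrete, here is a minimal, self-contained sketch (the history stub and sample turns are invented; in the component the turns come from `self._canvas.get_history(4)` and the completion from `chat_mdl.chat`): the recent conversation is spliced into the `{conv}` slot of the instruction, and the model is asked to complete `Output:` with a standalone question.

```python
# Stand-in for self._canvas.get_history(4); the turns below are invented.
def get_history_stub(window_size: int) -> list[str]:
    history = [
        "USER: What is the name of Donald Trump's father?",
        "ASSISTANT: Fred Trump.",
        "USER: And his mother?",
    ]
    return history[-window_size:]

conv = "\n".join(get_history_stub(4))

# Mirrors the shape of RewriteQuestionParam.get_prompt(conv): the conversation
# is embedded in the instruction itself, and the user message is just the cue.
system_prompt = (
    "Task: Generate a full user question that would follow the conversation.\n"
    "## Conversation\n"
    f"{conv}\n"
    "###############"
)
messages = [{"role": "user", "content": "Output: "}]

# A real run would pass these to chat_mdl.chat(system_prompt, messages, gen_conf());
# the expected completion is a standalone question such as:
#   "What's the name of Donald Trump's mother?"
print(system_prompt)
```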


@@ -112,7 +112,7 @@ def run():
         canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
         if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
             ten = TenantService.get_by_user_id(current_user.id)[0]
-            req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
+            #req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
         canvas.add_user_input(req["message"])
         answer = canvas.run(stream=stream)
         print(canvas)
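
Presumably the API-level `full_question` condensation is no longer needed because conversation-aware rewriting now happens inside the canvas (the `RewriteQuestion` change above); that reading is an assumption, not something stated in the diff. A tiny sketch of the request state on a second user turn, with invented values:

```python
# Hypothetical state on the second user turn (all values invented).
req = {"message": "And his mother?"}
canvas_messages = [
    {"role": "user", "content": "What is the name of Donald Trump's father?"},
    {"role": "assistant", "content": "Fred Trump."},
    {"role": "user", "content": "And his mother?"},
]

# Before: req["message"] was replaced by an LLM-condensed standalone question
# (full_question) before entering the canvas.
# After: the raw turn is passed through unchanged; any refinement is left to
# components inside the canvas.
print(req["message"])  # -> And his mother?
```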