mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 04:05:58 +08:00
refine agent (#2787)
### What problem does this PR solve? ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [x] Performance Improvement - [ ] Other (please describe):
This commit is contained in:
parent
6af9d4e5f9
commit
7742f67481
@ -73,7 +73,7 @@ class Categorize(Generate, ABC):
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
input = self.get_input()
|
||||
input = "Question: " + ("; ".join(input["content"]) if "content" in input else "") + "Category: "
|
||||
input = "Question: " + (list(input["content"])[-1] if "content" in input else "") + "\tCategory: "
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": input}],
|
||||
self._param.gen_conf())
|
||||
|
@ -101,7 +101,7 @@ class Generate(ComponentBase):
|
||||
prompt = self._param.prompt
|
||||
|
||||
retrieval_res = self.get_input()
|
||||
input = (" - " + "\n - ".join(retrieval_res["content"])) if "content" in retrieval_res else ""
|
||||
input = (" - "+"\n - ".join([c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
|
||||
for para in self._param.parameters:
|
||||
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
||||
_, out = cpn.output(allow_partial=False)
|
||||
|
@ -33,7 +33,7 @@ class RewriteQuestionParam(GenerateParam):
|
||||
def check(self):
|
||||
super().check()
|
||||
|
||||
def get_prompt(self):
|
||||
def get_prompt(self, conv):
|
||||
self.prompt = """
|
||||
You are an expert at query expansion to generate a paraphrasing of a question.
|
||||
I can't retrieval relevant information from the knowledge base by using user's question directly.
|
||||
@ -43,6 +43,40 @@ class RewriteQuestionParam(GenerateParam):
|
||||
And return 5 versions of question and one is from translation.
|
||||
Just list the question. No other words are needed.
|
||||
"""
|
||||
return f"""
|
||||
Role: A helpful assistant
|
||||
Task: Generate a full user question that would follow the conversation.
|
||||
Requirements & Restrictions:
|
||||
- Text generated MUST be in the same language of the original user's question.
|
||||
- If the user's latest question is completely, don't do anything, just return the original question.
|
||||
- DON'T generate anything except a refined question.
|
||||
|
||||
######################
|
||||
-Examples-
|
||||
######################
|
||||
# Example 1
|
||||
## Conversation
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
###############
|
||||
Output: What's the name of Donald Trump's mother?
|
||||
------------
|
||||
# Example 2
|
||||
## Conversation
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
ASSISTANT: Mary Trump.
|
||||
User: What's her full name?
|
||||
###############
|
||||
Output: What's the full name of Donald Trump's mother Mary Trump?
|
||||
######################
|
||||
# Real Data
|
||||
## Conversation
|
||||
{conv}
|
||||
###############
|
||||
"""
|
||||
return self.prompt
|
||||
|
||||
|
||||
@ -56,14 +90,12 @@ class RewriteQuestion(Generate, ABC):
|
||||
self._loop = 0
|
||||
raise Exception("Sorry! Nothing relevant found.")
|
||||
self._loop += 1
|
||||
q = "Question: "
|
||||
for r, c in self._canvas.history[::-1]:
|
||||
if r == "user":
|
||||
q += c
|
||||
break
|
||||
|
||||
conv = self._canvas.get_history(4)
|
||||
conv = "\n".join(conv)
|
||||
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}],
|
||||
ans = chat_mdl.chat(self._param.get_prompt(conv), [{"role": "user", "content": "Output: "}],
|
||||
self._param.gen_conf())
|
||||
self._canvas.history.pop()
|
||||
self._canvas.history.append(("user", ans))
|
||||
|
@ -112,7 +112,7 @@ def run():
|
||||
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
|
||||
if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
|
||||
ten = TenantService.get_by_user_id(current_user.id)[0]
|
||||
req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
|
||||
#req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
|
||||
canvas.add_user_input(req["message"])
|
||||
answer = canvas.run(stream=stream)
|
||||
print(canvas)
|
||||
|
Loading…
x
Reference in New Issue
Block a user