mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
Inner prompt parameter setting. (#4806)
### What problem does this PR solve?

#4764

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
5a51bdd824
commit
f64ae9dc33
@ -69,10 +69,8 @@ class Generate(ComponentBase):
|
|||||||
component_name = "Generate"
|
component_name = "Generate"
|
||||||
|
|
||||||
def get_dependent_components(self):
|
def get_dependent_components(self):
|
||||||
cpnts = set([para["component_id"].split("@")[0] for para in self._param.parameters \
|
inputs = self.get_input_elements()
|
||||||
if para.get("component_id") \
|
cpnts = set([i["key"] for i in inputs[1:] if i["key"].lower().find("answer") < 0 and i["key"].lower().find("begin") < 0])
|
||||||
and para["component_id"].lower().find("answer") < 0 \
|
|
||||||
and para["component_id"].lower().find("begin") < 0])
|
|
||||||
return list(cpnts)
|
return list(cpnts)
|
||||||
|
|
||||||
def set_cite(self, retrieval_res, answer):
|
def set_cite(self, retrieval_res, answer):
|
||||||
@ -110,10 +108,26 @@ class Generate(ComponentBase):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def get_input_elements(self):
|
def get_input_elements(self):
|
||||||
if self._param.parameters:
|
key_set = set([])
|
||||||
return [{"key": "user", "name": "Input your question here:"}, *self._param.parameters]
|
res = [{"key": "user", "name": "Input your question here:"}]
|
||||||
|
for r in re.finditer(r"\{([a-z]+[:@][a-z0-9_-]+)\}", self._param.prompt, flags=re.IGNORECASE):
|
||||||
return [{"key": "user", "name": "Input your question here:"}]
|
cpn_id = r.group(1)
|
||||||
|
if cpn_id in key_set:
|
||||||
|
continue
|
||||||
|
if cpn_id.lower().find("begin@") == 0:
|
||||||
|
cpn_id, key = cpn_id.split("@")
|
||||||
|
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||||
|
if p["key"] != key:
|
||||||
|
continue
|
||||||
|
res.append({"key": r.group(1), "name": p["name"]})
|
||||||
|
key_set.add(r.group(1))
|
||||||
|
continue
|
||||||
|
cpn_nm = self._canvas.get_compnent_name(cpn_id)
|
||||||
|
if not cpn_nm:
|
||||||
|
continue
|
||||||
|
res.append({"key": cpn_id, "name": cpn_nm})
|
||||||
|
key_set.add(cpn_id)
|
||||||
|
return res
|
||||||
|
|
||||||
def _run(self, history, **kwargs):
|
def _run(self, history, **kwargs):
|
||||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||||
@ -121,22 +135,20 @@ class Generate(ComponentBase):
|
|||||||
|
|
||||||
retrieval_res = []
|
retrieval_res = []
|
||||||
self._param.inputs = []
|
self._param.inputs = []
|
||||||
for para in self._param.parameters:
|
for para in self.get_input_elements()[1:]:
|
||||||
if not para.get("component_id"):
|
if para["key"].lower().find("begin@") == 0:
|
||||||
continue
|
cpn_id, key = para["key"].split("@")
|
||||||
component_id = para["component_id"].split("@")[0]
|
|
||||||
if para["component_id"].lower().find("@") >= 0:
|
|
||||||
cpn_id, key = para["component_id"].split("@")
|
|
||||||
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||||
if p["key"] == key:
|
if p["key"] == key:
|
||||||
kwargs[para["key"]] = p.get("value", "")
|
kwargs[para["key"]] = p.get("value", "")
|
||||||
self._param.inputs.append(
|
self._param.inputs.append(
|
||||||
{"component_id": para["component_id"], "content": kwargs[para["key"]]})
|
{"component_id": para["key"], "content": kwargs[para["key"]]})
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
component_id = para["key"]
|
||||||
cpn = self._canvas.get_component(component_id)["obj"]
|
cpn = self._canvas.get_component(component_id)["obj"]
|
||||||
if cpn.component_name.lower() == "answer":
|
if cpn.component_name.lower() == "answer":
|
||||||
hist = self._canvas.get_history(1)
|
hist = self._canvas.get_history(1)
|
||||||
@ -152,8 +164,8 @@ class Generate(ComponentBase):
|
|||||||
else:
|
else:
|
||||||
if cpn.component_name.lower() == "retrieval":
|
if cpn.component_name.lower() == "retrieval":
|
||||||
retrieval_res.append(out)
|
retrieval_res.append(out)
|
||||||
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
|
kwargs[para["key"]] = " - " + "\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
|
||||||
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
|
self._param.inputs.append({"component_id": para["key"], "content": kwargs[para["key"]]})
|
||||||
|
|
||||||
if retrieval_res:
|
if retrieval_res:
|
||||||
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
||||||
@ -175,16 +187,16 @@ class Generate(ComponentBase):
|
|||||||
return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
|
return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
|
||||||
|
|
||||||
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
|
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
|
||||||
res = {"content": "\n- ".join(retrieval_res["empty_response"]) if "\n- ".join(
|
empty_res = "\n- ".join([str(t) for t in retrieval_res["empty_response"] if str(t)])
|
||||||
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
|
res = {"content": empty_res if empty_res else "Nothing found in knowledgebase!", "reference": []}
|
||||||
return pd.DataFrame([res])
|
return pd.DataFrame([res])
|
||||||
|
|
||||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||||
if len(msg) < 1:
|
if len(msg) < 1:
|
||||||
msg.append({"role": "user", "content": ""})
|
msg.append({"role": "user", "content": "Output: "})
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
||||||
if len(msg) < 2:
|
if len(msg) < 2:
|
||||||
msg.append({"role": "user", "content": ""})
|
msg.append({"role": "user", "content": "Output: "})
|
||||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
|
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
|
||||||
|
|
||||||
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
|
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
|
||||||
@ -196,18 +208,18 @@ class Generate(ComponentBase):
|
|||||||
def stream_output(self, chat_mdl, prompt, retrieval_res):
|
def stream_output(self, chat_mdl, prompt, retrieval_res):
|
||||||
res = None
|
res = None
|
||||||
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
|
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
|
||||||
res = {"content": "\n- ".join(retrieval_res["empty_response"]) if "\n- ".join(
|
empty_res = "\n- ".join([str(t) for t in retrieval_res["empty_response"] if str(t)])
|
||||||
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
|
res = {"content": empty_res if empty_res else "Nothing found in knowledgebase!", "reference": []}
|
||||||
yield res
|
yield res
|
||||||
self.set_output(res)
|
self.set_output(res)
|
||||||
return
|
return
|
||||||
|
|
||||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||||
if len(msg) < 1:
|
if len(msg) < 1:
|
||||||
msg.append({"role": "user", "content": ""})
|
msg.append({"role": "user", "content": "Output: "})
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
||||||
if len(msg) < 2:
|
if len(msg) < 2:
|
||||||
msg.append({"role": "user", "content": ""})
|
msg.append({"role": "user", "content": "Output: "})
|
||||||
answer = ""
|
answer = ""
|
||||||
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
|
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
|
||||||
res = {"content": ans, "reference": []}
|
res = {"content": ans, "reference": []}
|
||||||
@ -230,5 +242,6 @@ class Generate(ComponentBase):
|
|||||||
for n, v in kwargs.items():
|
for n, v in kwargs.items():
|
||||||
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
|
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
|
||||||
|
|
||||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": kwargs.get("user", "")}], self._param.gen_conf())
|
u = kwargs.get("user")
|
||||||
|
ans = chat_mdl.chat(prompt, [{"role": "user", "content": u if u else "Output: "}], self._param.gen_conf())
|
||||||
return pd.DataFrame([ans])
|
return pd.DataFrame([ans])
|
||||||
|
@ -38,27 +38,39 @@ class Template(ComponentBase):
|
|||||||
component_name = "Template"
|
component_name = "Template"
|
||||||
|
|
||||||
def get_dependent_components(self):
|
def get_dependent_components(self):
|
||||||
cpnts = set(
|
inputs = self.get_input_elements()
|
||||||
[
|
cpnts = set([i["key"] for i in inputs if i["key"].lower().find("answer") < 0 and i["key"].lower().find("begin") < 0])
|
||||||
para["component_id"].split("@")[0]
|
|
||||||
for para in self._param.parameters
|
|
||||||
if para.get("component_id")
|
|
||||||
and para["component_id"].lower().find("answer") < 0
|
|
||||||
and para["component_id"].lower().find("begin") < 0
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return list(cpnts)
|
return list(cpnts)
|
||||||
|
|
||||||
|
def get_input_elements(self):
|
||||||
|
key_set = set([])
|
||||||
|
res = []
|
||||||
|
for r in re.finditer(r"\{([a-z]+[:@][a-z0-9_-]+)\}", self._param.content, flags=re.IGNORECASE):
|
||||||
|
cpn_id = r.group(1)
|
||||||
|
if cpn_id in key_set:
|
||||||
|
continue
|
||||||
|
if cpn_id.lower().find("begin@") == 0:
|
||||||
|
cpn_id, key = cpn_id.split("@")
|
||||||
|
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||||
|
if p["key"] != key:
|
||||||
|
continue
|
||||||
|
res.append({"key": r.group(1), "name": p["name"]})
|
||||||
|
key_set.add(r.group(1))
|
||||||
|
continue
|
||||||
|
cpn_nm = self._canvas.get_compnent_name(cpn_id)
|
||||||
|
if not cpn_nm:
|
||||||
|
continue
|
||||||
|
res.append({"key": cpn_id, "name": cpn_nm})
|
||||||
|
key_set.add(cpn_id)
|
||||||
|
return res
|
||||||
|
|
||||||
def _run(self, history, **kwargs):
|
def _run(self, history, **kwargs):
|
||||||
content = self._param.content
|
content = self._param.content
|
||||||
|
|
||||||
self._param.inputs = []
|
self._param.inputs = []
|
||||||
for para in self._param.parameters:
|
for para in self.get_input_elements():
|
||||||
if not para.get("component_id"):
|
if para["key"].lower().find("begin@") == 0:
|
||||||
continue
|
cpn_id, key = para["key"].split("@")
|
||||||
component_id = para["component_id"].split("@")[0]
|
|
||||||
if para["component_id"].lower().find("@") >= 0:
|
|
||||||
cpn_id, key = para["component_id"].split("@")
|
|
||||||
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||||
if p["key"] == key:
|
if p["key"] == key:
|
||||||
value = p.get("value", "")
|
value = p.get("value", "")
|
||||||
@ -68,6 +80,7 @@ class Template(ComponentBase):
|
|||||||
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
component_id = para["key"]
|
||||||
cpn = self._canvas.get_component(component_id)["obj"]
|
cpn = self._canvas.get_component(component_id)["obj"]
|
||||||
if cpn.component_name.lower() == "answer":
|
if cpn.component_name.lower() == "answer":
|
||||||
hist = self._canvas.get_history(1)
|
hist = self._canvas.get_history(1)
|
||||||
@ -114,7 +127,7 @@ class Template(ComponentBase):
|
|||||||
|
|
||||||
def make_kwargs(self, para, kwargs, value):
|
def make_kwargs(self, para, kwargs, value):
|
||||||
self._param.inputs.append(
|
self._param.inputs.append(
|
||||||
{"component_id": para["component_id"], "content": value}
|
{"component_id": para["key"], "content": value}
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
value = json.loads(value)
|
value = json.loads(value)
|
||||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user