Set component parameters inline in the prompt. (#4806)

### What problem does this PR solve?

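Related issue: #4764. Input parameters for the Generate and Template components are now derived from the `{component_id}` and `{begin@key}` placeholders found in the prompt/content text, instead of from a separately configured `parameters` list.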

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2025-02-08 18:09:02 +08:00 committed by GitHub
parent 5a51bdd824
commit f64ae9dc33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 10744 additions and 9935 deletions

View File

@ -69,10 +69,8 @@ class Generate(ComponentBase):
component_name = "Generate" component_name = "Generate"
def get_dependent_components(self): def get_dependent_components(self):
cpnts = set([para["component_id"].split("@")[0] for para in self._param.parameters \ inputs = self.get_input_elements()
if para.get("component_id") \ cpnts = set([i["key"] for i in inputs[1:] if i["key"].lower().find("answer") < 0 and i["key"].lower().find("begin") < 0])
and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts) return list(cpnts)
def set_cite(self, retrieval_res, answer): def set_cite(self, retrieval_res, answer):
@ -110,10 +108,26 @@ class Generate(ComponentBase):
return res return res
def get_input_elements(self): def get_input_elements(self):
if self._param.parameters: key_set = set([])
return [{"key": "user", "name": "Input your question here:"}, *self._param.parameters] res = [{"key": "user", "name": "Input your question here:"}]
for r in re.finditer(r"\{([a-z]+[:@][a-z0-9_-]+)\}", self._param.prompt, flags=re.IGNORECASE):
return [{"key": "user", "name": "Input your question here:"}] cpn_id = r.group(1)
if cpn_id in key_set:
continue
if cpn_id.lower().find("begin@") == 0:
cpn_id, key = cpn_id.split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] != key:
continue
res.append({"key": r.group(1), "name": p["name"]})
key_set.add(r.group(1))
continue
cpn_nm = self._canvas.get_compnent_name(cpn_id)
if not cpn_nm:
continue
res.append({"key": cpn_id, "name": cpn_nm})
key_set.add(cpn_id)
return res
def _run(self, history, **kwargs): def _run(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
@ -121,22 +135,20 @@ class Generate(ComponentBase):
retrieval_res = [] retrieval_res = []
self._param.inputs = [] self._param.inputs = []
for para in self._param.parameters: for para in self.get_input_elements()[1:]:
if not para.get("component_id"): if para["key"].lower().find("begin@") == 0:
continue cpn_id, key = para["key"].split("@")
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query: for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key: if p["key"] == key:
kwargs[para["key"]] = p.get("value", "") kwargs[para["key"]] = p.get("value", "")
self._param.inputs.append( self._param.inputs.append(
{"component_id": para["component_id"], "content": kwargs[para["key"]]}) {"component_id": para["key"], "content": kwargs[para["key"]]})
break break
else: else:
assert False, f"Can't find parameter '{key}' for {cpn_id}" assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue continue
component_id = para["key"]
cpn = self._canvas.get_component(component_id)["obj"] cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer": if cpn.component_name.lower() == "answer":
hist = self._canvas.get_history(1) hist = self._canvas.get_history(1)
@ -152,8 +164,8 @@ class Generate(ComponentBase):
else: else:
if cpn.component_name.lower() == "retrieval": if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out) retrieval_res.append(out)
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]]) kwargs[para["key"]] = " - " + "\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]}) self._param.inputs.append({"component_id": para["key"], "content": kwargs[para["key"]]})
if retrieval_res: if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True) retrieval_res = pd.concat(retrieval_res, ignore_index=True)
@ -175,16 +187,16 @@ class Generate(ComponentBase):
return partial(self.stream_output, chat_mdl, prompt, retrieval_res) return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]): if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
res = {"content": "\n- ".join(retrieval_res["empty_response"]) if "\n- ".join( empty_res = "\n- ".join([str(t) for t in retrieval_res["empty_response"] if str(t)])
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []} res = {"content": empty_res if empty_res else "Nothing found in knowledgebase!", "reference": []}
return pd.DataFrame([res]) return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size) msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: if len(msg) < 1:
msg.append({"role": "user", "content": ""}) msg.append({"role": "user", "content": "Output: "})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97)) _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: if len(msg) < 2:
msg.append({"role": "user", "content": ""}) msg.append({"role": "user", "content": "Output: "})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf()) ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
@ -196,18 +208,18 @@ class Generate(ComponentBase):
def stream_output(self, chat_mdl, prompt, retrieval_res): def stream_output(self, chat_mdl, prompt, retrieval_res):
res = None res = None
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]): if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
res = {"content": "\n- ".join(retrieval_res["empty_response"]) if "\n- ".join( empty_res = "\n- ".join([str(t) for t in retrieval_res["empty_response"] if str(t)])
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []} res = {"content": empty_res if empty_res else "Nothing found in knowledgebase!", "reference": []}
yield res yield res
self.set_output(res) self.set_output(res)
return return
msg = self._canvas.get_history(self._param.message_history_window_size) msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: if len(msg) < 1:
msg.append({"role": "user", "content": ""}) msg.append({"role": "user", "content": "Output: "})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97)) _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: if len(msg) < 2:
msg.append({"role": "user", "content": ""}) msg.append({"role": "user", "content": "Output: "})
answer = "" answer = ""
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()): for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []} res = {"content": ans, "reference": []}
@ -230,5 +242,6 @@ class Generate(ComponentBase):
for n, v in kwargs.items(): for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt) prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
ans = chat_mdl.chat(prompt, [{"role": "user", "content": kwargs.get("user", "")}], self._param.gen_conf()) u = kwargs.get("user")
ans = chat_mdl.chat(prompt, [{"role": "user", "content": u if u else "Output: "}], self._param.gen_conf())
return pd.DataFrame([ans]) return pd.DataFrame([ans])

View File

@ -38,27 +38,39 @@ class Template(ComponentBase):
component_name = "Template" component_name = "Template"
def get_dependent_components(self): def get_dependent_components(self):
cpnts = set( inputs = self.get_input_elements()
[ cpnts = set([i["key"] for i in inputs if i["key"].lower().find("answer") < 0 and i["key"].lower().find("begin") < 0])
para["component_id"].split("@")[0]
for para in self._param.parameters
if para.get("component_id")
and para["component_id"].lower().find("answer") < 0
and para["component_id"].lower().find("begin") < 0
]
)
return list(cpnts) return list(cpnts)
def get_input_elements(self):
key_set = set([])
res = []
for r in re.finditer(r"\{([a-z]+[:@][a-z0-9_-]+)\}", self._param.content, flags=re.IGNORECASE):
cpn_id = r.group(1)
if cpn_id in key_set:
continue
if cpn_id.lower().find("begin@") == 0:
cpn_id, key = cpn_id.split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] != key:
continue
res.append({"key": r.group(1), "name": p["name"]})
key_set.add(r.group(1))
continue
cpn_nm = self._canvas.get_compnent_name(cpn_id)
if not cpn_nm:
continue
res.append({"key": cpn_id, "name": cpn_nm})
key_set.add(cpn_id)
return res
def _run(self, history, **kwargs): def _run(self, history, **kwargs):
content = self._param.content content = self._param.content
self._param.inputs = [] self._param.inputs = []
for para in self._param.parameters: for para in self.get_input_elements():
if not para.get("component_id"): if para["key"].lower().find("begin@") == 0:
continue cpn_id, key = para["key"].split("@")
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query: for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key: if p["key"] == key:
value = p.get("value", "") value = p.get("value", "")
@ -68,6 +80,7 @@ class Template(ComponentBase):
assert False, f"Can't find parameter '{key}' for {cpn_id}" assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue continue
component_id = para["key"]
cpn = self._canvas.get_component(component_id)["obj"] cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer": if cpn.component_name.lower() == "answer":
hist = self._canvas.get_history(1) hist = self._canvas.get_history(1)
@ -114,7 +127,7 @@ class Template(ComponentBase):
def make_kwargs(self, para, kwargs, value): def make_kwargs(self, para, kwargs, value):
self._param.inputs.append( self._param.inputs.append(
{"component_id": para["component_id"], "content": value} {"component_id": para["key"], "content": value}
) )
try: try:
value = json.loads(value) value = json.loads(value)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long