add input variables to begin component (#3498)

### What problem does this PR solve?

#3355 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2024-11-19 18:41:48 +08:00 committed by GitHub
parent 0cd5b64c3b
commit 361cff34fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 128 additions and 78 deletions

View File

@ -156,8 +156,12 @@ class Canvas(ABC):
self.components[k]["obj"].reset()
self._embed_id = ""
def get_compnent_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
if cid == n["id"]: return n["data"]["name"]
return ""
def run(self, **kwargs):
ans = ""
if self.answer:
cpn_id = self.answer[0]
self.answer.pop(0)
@ -167,10 +171,10 @@ class Canvas(ABC):
ans = ComponentBase.be_output(str(e))
self.path[-1].append(cpn_id)
if kwargs.get("stream"):
assert isinstance(ans, partial)
return ans
self.history.append(("assistant", ans.to_dict("records")))
return ans
for an in ans():
yield an
else: yield ans
return
if not self.path:
self.components["begin"]["obj"].run(self.history, **kwargs)
@ -178,6 +182,8 @@ class Canvas(ABC):
self.path.append([])
ran = -1
waiting = []
without_dependent_checking = []
def prepare2run(cpns):
nonlocal ran, ans
@ -188,14 +194,19 @@ class Canvas(ABC):
self.answer.append(c)
else:
logging.debug(f"Canvas.prepare2run: {c}")
cpids = cpn.get_dependent_components()
if any([c not in self.path[-1] for c in cpids]):
continue
if c not in without_dependent_checking:
cpids = cpn.get_dependent_components()
if any([cc not in self.path[-1] for cc in cpids]):
if c not in waiting: waiting.append(c)
continue
yield "'{}' is running...".format(self.get_compnent_name(c))
ans = cpn.run(self.history, **kwargs)
self.path[-1].append(c)
ran += 1
prepare2run(self.components[self.path[-2][-1]]["downstream"])
for m in prepare2run(self.components[self.path[-2][-1]]["downstream"]):
yield {"content": m, "running_status": True}
while 0 <= ran < len(self.path[-1]):
logging.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran]
@ -210,28 +221,39 @@ class Canvas(ABC):
assert switch_out in self.components, \
"{}'s output: {} not valid.".format(cpn_id, switch_out)
try:
prepare2run([switch_out])
for m in prepare2run([switch_out]):
yield {"content": m, "running_status": True}
except Exception as e:
for p in [c for p in self.path for c in p][::-1]:
if p.lower().find("answer") >= 0:
self.get_component(p)["obj"].set_exception(e)
prepare2run([p])
for m in prepare2run([p]):
yield {"content": m, "running_status": True}
break
logging.exception("Canvas.run got exception")
break
continue
try:
prepare2run(cpn["downstream"])
for m in prepare2run(cpn["downstream"]):
yield {"content": m, "running_status": True}
except Exception as e:
for p in [c for p in self.path for c in p][::-1]:
if p.lower().find("answer") >= 0:
self.get_component(p)["obj"].set_exception(e)
prepare2run([p])
for m in prepare2run([p]):
yield {"content": m, "running_status": True}
break
logging.exception("Canvas.run got exception")
break
if ran >= len(self.path[-1]) and waiting:
without_dependent_checking = waiting
waiting = []
for m in prepare2run(without_dependent_checking):
yield {"content": m, "running_status": True}
ran -= 1
if self.answer:
cpn_id = self.answer[0]
self.answer.pop(0)
@ -239,11 +261,13 @@ class Canvas(ABC):
self.path[-1].append(cpn_id)
if kwargs.get("stream"):
assert isinstance(ans, partial)
return ans
for an in ans():
yield an
else:
yield ans
self.history.append(("assistant", ans.to_dict("records")))
return ans
else:
raise Exception("The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow.")
def get_component(self, cpn_id):
return self.components[cpn_id]

View File

@ -13,17 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import builtins
import json
import os
from functools import partial
from typing import Tuple, Union
import pandas as pd
from agent import settings
from agent.settings import flow_logger, DEBUG
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
@ -82,7 +82,6 @@ class ComponentParamBase(ABC):
return {name: True for name in self.get_feeded_deprecated_params()}
def __str__(self):
return json.dumps(self.as_dict(), ensure_ascii=False)
def as_dict(self):
@ -398,8 +397,11 @@ class ComponentBase(ABC):
self._param.check()
def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.query if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts
cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \
if para.get("component_id") \
and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts)
def run(self, history, **kwargs):
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
@ -416,7 +418,7 @@ class ComponentBase(ABC):
def _run(self, history, **kwargs):
raise NotImplementedError()
def output(self, allow_partial=True) -> tuple[str, pd.DataFrame | partial]:
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o]
@ -436,12 +438,19 @@ class ComponentBase(ABC):
def reset(self):
setattr(self._param, self._param.output_var_name, None)
self._param.inputs = []
def set_output(self, v: pd.DataFrame):
setattr(self._param, self._param.output_var_name, v)
def get_input(self):
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
if self._param.query:
self._param.inputs = []
outs = []
for q in self._param.query:
if q["component_id"]:
@ -449,9 +458,9 @@ class ComponentBase(ABC):
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p["value"]}]))
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
self._param.inputs.append({"component_id": q["component_id"],
"content": p["value"]})
"content": p.get("value", "")})
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
@ -470,12 +479,8 @@ class ComponentBase(ABC):
return df
upstream_outs = []
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
logging.debug(f"{self.component_name} {reversed_cpnts[::-1]}")
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
@ -484,7 +489,7 @@ class ComponentBase(ABC):
o["component_id"] = u
upstream_outs.append(o)
continue
if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
#if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
if self.component_name.lower().find("switch") < 0 \
and self.get_component_name(u) in ["relevant", "categorize"]:
continue
@ -502,14 +507,14 @@ class ComponentBase(ABC):
upstream_outs.append(o)
break
assert upstream_outs, "Can't inference the where the component input is."
assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
df = pd.concat(upstream_outs, ignore_index=True)
if "content" in df:
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
self._param.inputs = []
for _,r in df.iterrows():
for _, r in df.iterrows():
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
return df

View File

@ -63,9 +63,11 @@ class Generate(ComponentBase):
component_name = "Generate"
def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts
cpnts = set([para["component_id"].split("@")[0] for para in self._param.parameters \
if para.get("component_id") \
and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts)
def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
@ -107,11 +109,12 @@ class Generate(ComponentBase):
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if para["component_id"].split("@")[0].lower().find("begin") > 0:
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
kwargs[para["key"]] = p["value"]
kwargs[para["key"]] = p.get("value", "")
self._param.inputs.append(
{"component_id": para["component_id"], "content": kwargs[para["key"]]})
break
@ -119,7 +122,7 @@ class Generate(ComponentBase):
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
cpn = self._canvas.get_component(para["component_id"])["obj"]
cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer":
kwargs[para["key"]] = self._canvas.get_history(1)[0]["content"]
continue
@ -129,14 +132,12 @@ class Generate(ComponentBase):
else:
if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out)
kwargs[para["key"]] = " - " + "\n - ".join(
[o if isinstance(o, str) else str(o) for o in out["content"]])
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else:
retrieval_res = pd.DataFrame([])
else: retrieval_res = pd.DataFrame([])
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
@ -158,6 +159,7 @@ class Generate(ComponentBase):
return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
@ -178,6 +180,7 @@ class Generate(ComponentBase):
return
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
answer = ""

View File

@ -47,13 +47,35 @@ class SwitchParam(ComponentParamBase):
class Switch(ComponentBase, ABC):
component_name = "Switch"
def get_dependent_components(self):
res = []
for cond in self._param.conditions:
for item in cond["items"]:
if not item["cpn_id"]: continue
if item["cpn_id"].find("begin") >= 0:
continue
cid = item["cpn_id"].split("@")[0]
res.append(cid)
return list(set(res))
def _run(self, history, **kwargs):
for cond in self._param.conditions:
res = []
for item in cond["items"]:
out = self._canvas.get_component(item["cpn_id"])["obj"].output()[1]
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
res.append(self.process_operator(cpn_input, item["operator"], item["value"]))
if not item["cpn_id"]:continue
cid = item["cpn_id"].split("@")[0]
if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
for p in self._canvas.get_component(cid)["obj"]._param.query:
if p["key"] == key:
res.append(self.process_operator(p.get("value",""), item["operator"], item.get("value", "")))
break
else:
out = self._canvas.get_component(cid)["obj"].output()[1]
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
res.append(self.process_operator(cpn_input, item["operator"], item.get("value", "")))
if cond["logical_operator"] != "and" and any(res):
return Switch.be_output(cond["to"])

View File

@ -15,11 +15,12 @@
#
import logging
import json
import traceback
from functools import partial
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api import settings
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
@ -36,8 +37,7 @@ def templates():
@login_required
def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in \
UserCanvasService.query(user_id=current_user.id)],
key=lambda x: x["update_time"] * -1)
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
)
@ -46,10 +46,10 @@ def canvas_list():
@login_required
def rm():
for i in request.json["canvas_ids"]:
if not UserCanvasService.query(user_id=current_user.id, id=i):
if not UserCanvasService.query(user_id=current_user.id,id=i):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
code=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)
@ -73,7 +73,7 @@ def save():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
code=RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req)
return get_json_result(data=req)
@ -99,7 +99,7 @@ def run():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
code=RetCode.OPERATING_ERROR)
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@ -110,26 +110,18 @@ def run():
canvas = Canvas(cvs.dsl, current_user.id)
if "message" in req:
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
# ten = TenantService.get_info_by(current_user.id)[0]
# req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
pass
canvas.add_user_input(req["message"])
answer = canvas.run(stream=stream)
logging.debug(canvas)
except Exception as e:
return server_error_response(e)
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
if stream:
assert isinstance(answer,
partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
def sse():
nonlocal answer, cvs
try:
for ans in answer():
for ans in canvas.run(stream=True):
if ans.get("running_status"):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
continue
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
@ -142,6 +134,7 @@ def run():
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
except Exception as e:
traceback.print_exc()
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
@ -154,13 +147,15 @@ def run():
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})
for answer in canvas.run(stream=False):
if answer.get("running_status"): continue
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})
@manager.route('/reset', methods=['POST'])
@ -175,7 +170,7 @@ def reset():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
code=RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()

View File

@ -563,13 +563,13 @@ def parse():
self.filepath = filepath
def read(self):
with open(self.filepath, "r") as f:
with open(self.filepath, "rb") as f:
return f.read()
r = re.search(r"filename=\"([^\"])\"", json.dumps(res_headers))
if not r or r.group(1):
r = re.search(r"filename=\"([^\"]+)\"", str(res_headers))
if not r or not r.group(1):
return get_json_result(
data=False, message="Can't not identify downloaded file", code=RetCode.ARGUMENT_ERROR)
data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
f = File(r.group(1), os.path.join(download_path, r.group(1)))
txt = FileService.parse_docs([f], current_user.id)
return get_json_result(data=txt)

View File

@ -98,7 +98,8 @@ def message_fit_in(msg, max_length=4000):
return c, msg
msg_ = [m for m in msg[:-1] if m["role"] == "system"]
msg_.append(msg[-1])
if len(msg) > 1:
msg_.append(msg[-1])
msg = msg_
c = count()
if c < max_length: