add input variables to begin component (#3498)

### What problem does this PR solve?

#3355 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2024-11-19 18:41:48 +08:00 committed by GitHub
parent 0cd5b64c3b
commit 361cff34fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 128 additions and 78 deletions

View File

@ -156,8 +156,12 @@ class Canvas(ABC):
self.components[k]["obj"].reset() self.components[k]["obj"].reset()
self._embed_id = "" self._embed_id = ""
def get_compnent_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
if cid == n["id"]: return n["data"]["name"]
return ""
def run(self, **kwargs): def run(self, **kwargs):
ans = ""
if self.answer: if self.answer:
cpn_id = self.answer[0] cpn_id = self.answer[0]
self.answer.pop(0) self.answer.pop(0)
@ -167,10 +171,10 @@ class Canvas(ABC):
ans = ComponentBase.be_output(str(e)) ans = ComponentBase.be_output(str(e))
self.path[-1].append(cpn_id) self.path[-1].append(cpn_id)
if kwargs.get("stream"): if kwargs.get("stream"):
assert isinstance(ans, partial) for an in ans():
return ans yield an
self.history.append(("assistant", ans.to_dict("records"))) else: yield ans
return ans return
if not self.path: if not self.path:
self.components["begin"]["obj"].run(self.history, **kwargs) self.components["begin"]["obj"].run(self.history, **kwargs)
@ -178,6 +182,8 @@ class Canvas(ABC):
self.path.append([]) self.path.append([])
ran = -1 ran = -1
waiting = []
without_dependent_checking = []
def prepare2run(cpns): def prepare2run(cpns):
nonlocal ran, ans nonlocal ran, ans
@ -188,14 +194,19 @@ class Canvas(ABC):
self.answer.append(c) self.answer.append(c)
else: else:
logging.debug(f"Canvas.prepare2run: {c}") logging.debug(f"Canvas.prepare2run: {c}")
cpids = cpn.get_dependent_components() if c not in without_dependent_checking:
if any([c not in self.path[-1] for c in cpids]): cpids = cpn.get_dependent_components()
continue if any([cc not in self.path[-1] for cc in cpids]):
if c not in waiting: waiting.append(c)
continue
yield "'{}' is running...".format(self.get_compnent_name(c))
ans = cpn.run(self.history, **kwargs) ans = cpn.run(self.history, **kwargs)
self.path[-1].append(c) self.path[-1].append(c)
ran += 1 ran += 1
prepare2run(self.components[self.path[-2][-1]]["downstream"]) for m in prepare2run(self.components[self.path[-2][-1]]["downstream"]):
yield {"content": m, "running_status": True}
while 0 <= ran < len(self.path[-1]): while 0 <= ran < len(self.path[-1]):
logging.debug(f"Canvas.run: {ran} {self.path}") logging.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran] cpn_id = self.path[-1][ran]
@ -210,28 +221,39 @@ class Canvas(ABC):
assert switch_out in self.components, \ assert switch_out in self.components, \
"{}'s output: {} not valid.".format(cpn_id, switch_out) "{}'s output: {} not valid.".format(cpn_id, switch_out)
try: try:
prepare2run([switch_out]) for m in prepare2run([switch_out]):
yield {"content": m, "running_status": True}
except Exception as e: except Exception as e:
for p in [c for p in self.path for c in p][::-1]: for p in [c for p in self.path for c in p][::-1]:
if p.lower().find("answer") >= 0: if p.lower().find("answer") >= 0:
self.get_component(p)["obj"].set_exception(e) self.get_component(p)["obj"].set_exception(e)
prepare2run([p]) for m in prepare2run([p]):
yield {"content": m, "running_status": True}
break break
logging.exception("Canvas.run got exception") logging.exception("Canvas.run got exception")
break break
continue continue
try: try:
prepare2run(cpn["downstream"]) for m in prepare2run(cpn["downstream"]):
yield {"content": m, "running_status": True}
except Exception as e: except Exception as e:
for p in [c for p in self.path for c in p][::-1]: for p in [c for p in self.path for c in p][::-1]:
if p.lower().find("answer") >= 0: if p.lower().find("answer") >= 0:
self.get_component(p)["obj"].set_exception(e) self.get_component(p)["obj"].set_exception(e)
prepare2run([p]) for m in prepare2run([p]):
yield {"content": m, "running_status": True}
break break
logging.exception("Canvas.run got exception") logging.exception("Canvas.run got exception")
break break
if ran >= len(self.path[-1]) and waiting:
without_dependent_checking = waiting
waiting = []
for m in prepare2run(without_dependent_checking):
yield {"content": m, "running_status": True}
ran -= 1
if self.answer: if self.answer:
cpn_id = self.answer[0] cpn_id = self.answer[0]
self.answer.pop(0) self.answer.pop(0)
@ -239,11 +261,13 @@ class Canvas(ABC):
self.path[-1].append(cpn_id) self.path[-1].append(cpn_id)
if kwargs.get("stream"): if kwargs.get("stream"):
assert isinstance(ans, partial) assert isinstance(ans, partial)
return ans for an in ans():
yield an
else:
yield ans
self.history.append(("assistant", ans.to_dict("records"))) else:
raise Exception("The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow.")
return ans
def get_component(self, cpn_id): def get_component(self, cpn_id):
return self.components[cpn_id] return self.components[cpn_id]

View File

@ -13,17 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging
from abc import ABC from abc import ABC
import builtins import builtins
import json import json
import os import os
from functools import partial from functools import partial
from typing import Tuple, Union
import pandas as pd import pandas as pd
from agent import settings from agent import settings
from agent.settings import flow_logger, DEBUG
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params"
@ -82,7 +82,6 @@ class ComponentParamBase(ABC):
return {name: True for name in self.get_feeded_deprecated_params()} return {name: True for name in self.get_feeded_deprecated_params()}
def __str__(self): def __str__(self):
return json.dumps(self.as_dict(), ensure_ascii=False) return json.dumps(self.as_dict(), ensure_ascii=False)
def as_dict(self): def as_dict(self):
@ -398,8 +397,11 @@ class ComponentBase(ABC):
self._param.check() self._param.check()
def get_dependent_components(self): def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.query if para.get("component_id") and para["component_id"].lower().find("answer") < 0] cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \
return cpnts if para.get("component_id") \
and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts)
def run(self, history, **kwargs): def run(self, history, **kwargs):
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
@ -416,7 +418,7 @@ class ComponentBase(ABC):
def _run(self, history, **kwargs): def _run(self, history, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def output(self, allow_partial=True) -> tuple[str, pd.DataFrame | partial]: def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name) o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame): if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o] if not isinstance(o, list): o = [o]
@ -436,12 +438,19 @@ class ComponentBase(ABC):
def reset(self): def reset(self):
setattr(self._param, self._param.output_var_name, None) setattr(self._param, self._param.output_var_name, None)
self._param.inputs = []
def set_output(self, v: pd.DataFrame): def set_output(self, v: pd.DataFrame):
setattr(self._param, self._param.output_var_name, v) setattr(self._param, self._param.output_var_name, v)
def get_input(self): def get_input(self):
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
if self._param.query: if self._param.query:
self._param.inputs = []
outs = [] outs = []
for q in self._param.query: for q in self._param.query:
if q["component_id"]: if q["component_id"]:
@ -449,9 +458,9 @@ class ComponentBase(ABC):
cpn_id, key = q["component_id"].split("@") cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query: for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key: if p["key"] == key:
outs.append(pd.DataFrame([{"content": p["value"]}])) outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
self._param.inputs.append({"component_id": q["component_id"], self._param.inputs.append({"component_id": q["component_id"],
"content": p["value"]}) "content": p.get("value", "")})
break break
else: else:
assert False, f"Can't find parameter '{key}' for {cpn_id}" assert False, f"Can't find parameter '{key}' for {cpn_id}"
@ -470,12 +479,8 @@ class ComponentBase(ABC):
return df return df
upstream_outs = [] upstream_outs = []
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
logging.debug(f"{self.component_name} {reversed_cpnts[::-1]}") if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]: for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
@ -484,7 +489,7 @@ class ComponentBase(ABC):
o["component_id"] = u o["component_id"] = u
upstream_outs.append(o) upstream_outs.append(o)
continue continue
if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue #if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
if self.component_name.lower().find("switch") < 0 \ if self.component_name.lower().find("switch") < 0 \
and self.get_component_name(u) in ["relevant", "categorize"]: and self.get_component_name(u) in ["relevant", "categorize"]:
continue continue
@ -502,14 +507,14 @@ class ComponentBase(ABC):
upstream_outs.append(o) upstream_outs.append(o)
break break
assert upstream_outs, "Can't inference the where the component input is." assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
df = pd.concat(upstream_outs, ignore_index=True) df = pd.concat(upstream_outs, ignore_index=True)
if "content" in df: if "content" in df:
df = df.drop_duplicates(subset=['content']).reset_index(drop=True) df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
self._param.inputs = [] self._param.inputs = []
for _,r in df.iterrows(): for _, r in df.iterrows():
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]}) self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
return df return df

View File

@ -63,9 +63,11 @@ class Generate(ComponentBase):
component_name = "Generate" component_name = "Generate"
def get_dependent_components(self): def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if cpnts = set([para["component_id"].split("@")[0] for para in self._param.parameters \
para.get("component_id") and para["component_id"].lower().find("answer") < 0] if para.get("component_id") \
return cpnts and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts)
def set_cite(self, retrieval_res, answer): def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True) retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
@ -107,11 +109,12 @@ class Generate(ComponentBase):
self._param.inputs = [] self._param.inputs = []
for para in self._param.parameters: for para in self._param.parameters:
if not para.get("component_id"): continue if not para.get("component_id"): continue
if para["component_id"].split("@")[0].lower().find("begin") > 0: component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@") cpn_id, key = para["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query: for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key: if p["key"] == key:
kwargs[para["key"]] = p["value"] kwargs[para["key"]] = p.get("value", "")
self._param.inputs.append( self._param.inputs.append(
{"component_id": para["component_id"], "content": kwargs[para["key"]]}) {"component_id": para["component_id"], "content": kwargs[para["key"]]})
break break
@ -119,7 +122,7 @@ class Generate(ComponentBase):
assert False, f"Can't find parameter '{key}' for {cpn_id}" assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue continue
cpn = self._canvas.get_component(para["component_id"])["obj"] cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer": if cpn.component_name.lower() == "answer":
kwargs[para["key"]] = self._canvas.get_history(1)[0]["content"] kwargs[para["key"]] = self._canvas.get_history(1)[0]["content"]
continue continue
@ -129,14 +132,12 @@ class Generate(ComponentBase):
else: else:
if cpn.component_name.lower() == "retrieval": if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out) retrieval_res.append(out)
kwargs[para["key"]] = " - " + "\n - ".join( kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
[o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]}) self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
if retrieval_res: if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True) retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: else: retrieval_res = pd.DataFrame([])
retrieval_res = pd.DataFrame([])
for n, v in kwargs.items(): for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt) prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
@ -158,6 +159,7 @@ class Generate(ComponentBase):
return pd.DataFrame([res]) return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size) msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97)) _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""}) if len(msg) < 2: msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf()) ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
@ -178,6 +180,7 @@ class Generate(ComponentBase):
return return
msg = self._canvas.get_history(self._param.message_history_window_size) msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97)) _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""}) if len(msg) < 2: msg.append({"role": "user", "content": ""})
answer = "" answer = ""

View File

@ -47,13 +47,35 @@ class SwitchParam(ComponentParamBase):
class Switch(ComponentBase, ABC): class Switch(ComponentBase, ABC):
component_name = "Switch" component_name = "Switch"
def get_dependent_components(self):
res = []
for cond in self._param.conditions:
for item in cond["items"]:
if not item["cpn_id"]: continue
if item["cpn_id"].find("begin") >= 0:
continue
cid = item["cpn_id"].split("@")[0]
res.append(cid)
return list(set(res))
def _run(self, history, **kwargs): def _run(self, history, **kwargs):
for cond in self._param.conditions: for cond in self._param.conditions:
res = [] res = []
for item in cond["items"]: for item in cond["items"]:
out = self._canvas.get_component(item["cpn_id"])["obj"].output()[1] if not item["cpn_id"]:continue
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]]) cid = item["cpn_id"].split("@")[0]
res.append(self.process_operator(cpn_input, item["operator"], item["value"])) if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
for p in self._canvas.get_component(cid)["obj"]._param.query:
if p["key"] == key:
res.append(self.process_operator(p.get("value",""), item["operator"], item.get("value", "")))
break
else:
out = self._canvas.get_component(cid)["obj"].output()[1]
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
res.append(self.process_operator(cpn_input, item["operator"], item.get("value", "")))
if cond["logical_operator"] != "and" and any(res): if cond["logical_operator"] != "and" and any(res):
return Switch.be_output(cond["to"]) return Switch.be_output(cond["to"])

View File

@ -15,11 +15,12 @@
# #
import logging import logging
import json import json
import traceback
from functools import partial from functools import partial
from flask import request, Response from flask import request, Response
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api import settings from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas from agent.canvas import Canvas
@ -36,8 +37,7 @@ def templates():
@login_required @login_required
def canvas_list(): def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in \ return get_json_result(data=sorted([c.to_dict() for c in \
UserCanvasService.query(user_id=current_user.id)], UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
key=lambda x: x["update_time"] * -1)
) )
@ -46,10 +46,10 @@ def canvas_list():
@login_required @login_required
def rm(): def rm():
for i in request.json["canvas_ids"]: for i in request.json["canvas_ids"]:
if not UserCanvasService.query(user_id=current_user.id, id=i): if not UserCanvasService.query(user_id=current_user.id,id=i):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i) UserCanvasService.delete_by_id(i)
return get_json_result(data=True) return get_json_result(data=True)
@ -73,7 +73,7 @@ def save():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req) UserCanvasService.update_by_id(req["id"], req)
return get_json_result(data=req) return get_json_result(data=req)
@ -99,7 +99,7 @@ def run():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
if not isinstance(cvs.dsl, str): if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@ -110,26 +110,18 @@ def run():
canvas = Canvas(cvs.dsl, current_user.id) canvas = Canvas(cvs.dsl, current_user.id)
if "message" in req: if "message" in req:
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id}) canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
# ten = TenantService.get_info_by(current_user.id)[0]
# req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
pass
canvas.add_user_input(req["message"]) canvas.add_user_input(req["message"])
answer = canvas.run(stream=stream)
logging.debug(canvas)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
if stream: if stream:
assert isinstance(answer,
partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
def sse(): def sse():
nonlocal answer, cvs nonlocal answer, cvs
try: try:
for ans in answer(): for ans in canvas.run(stream=True):
if ans.get("running_status"):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
continue
for k in ans.keys(): for k in ans.keys():
final_ans[k] = ans[k] final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])} ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
@ -142,6 +134,7 @@ def run():
cvs.dsl = json.loads(str(canvas)) cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict()) UserCanvasService.update_by_id(req["id"], cvs.to_dict())
except Exception as e: except Exception as e:
traceback.print_exc()
yield "data:" + json.dumps({"code": 500, "message": str(e), yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}}, "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n" ensure_ascii=False) + "\n\n"
@ -154,13 +147,15 @@ def run():
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp return resp
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" for answer in canvas.run(stream=False):
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) if answer.get("running_status"): continue
if final_ans.get("reference"): final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.reference.append(final_ans["reference"]) canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
cvs.dsl = json.loads(str(canvas)) if final_ans.get("reference"):
UserCanvasService.update_by_id(req["id"], cvs.to_dict()) canvas.reference.append(final_ans["reference"])
return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])}) cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})
@manager.route('/reset', methods=['POST']) @manager.route('/reset', methods=['POST'])
@ -175,7 +170,7 @@ def reset():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset() canvas.reset()

View File

@ -563,13 +563,13 @@ def parse():
self.filepath = filepath self.filepath = filepath
def read(self): def read(self):
with open(self.filepath, "r") as f: with open(self.filepath, "rb") as f:
return f.read() return f.read()
r = re.search(r"filename=\"([^\"])\"", json.dumps(res_headers)) r = re.search(r"filename=\"([^\"]+)\"", str(res_headers))
if not r or r.group(1): if not r or not r.group(1):
return get_json_result( return get_json_result(
data=False, message="Can't not identify downloaded file", code=RetCode.ARGUMENT_ERROR) data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
f = File(r.group(1), os.path.join(download_path, r.group(1))) f = File(r.group(1), os.path.join(download_path, r.group(1)))
txt = FileService.parse_docs([f], current_user.id) txt = FileService.parse_docs([f], current_user.id)
return get_json_result(data=txt) return get_json_result(data=txt)

View File

@ -98,7 +98,8 @@ def message_fit_in(msg, max_length=4000):
return c, msg return c, msg
msg_ = [m for m in msg[:-1] if m["role"] == "system"] msg_ = [m for m in msg[:-1] if m["role"] == "system"]
msg_.append(msg[-1]) if len(msg) > 1:
msg_.append(msg[-1])
msg = msg_ msg = msg_
c = count() c = count()
if c < max_length: if c < max_length: