Support debug components. (#3994)

### What problem does this PR solve?

#3993

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu 2024-12-11 19:23:59 +08:00 committed by GitHub
parent f61c276f74
commit 6d19294ddc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 63 additions and 15 deletions

View File

@ -211,6 +211,7 @@ class Canvas(ABC):
except Exception as e:
logging.exception(f"Canvas.run got exception: {e}")
self.path[-1].append(c)
ran += 1
raise e
self.path[-1].append(c)
ran += 1
@ -330,4 +331,4 @@ class Canvas(ABC):
return self.components["begin"]["obj"]._param.query
def get_component_input_elements(self, cpnnm):
return self.components["begin"]["obj"].get_input_elements()
return self.components[cpnnm]["obj"].get_input_elements()

View File

@ -37,6 +37,7 @@ class ComponentParamBase(ABC):
self.message_history_window_size = 22
self.query = []
self.inputs = []
self.debug_inputs = []
def set_name(self, name: str):
self._name = name
@ -410,6 +411,7 @@ class ComponentBase(ABC):
def run(self, history, **kwargs):
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
json.dumps(kwargs, ensure_ascii=False)))
self._param.debug_inputs = []
try:
res = self._run(history, **kwargs)
self.set_output(res)
@ -446,10 +448,13 @@ class ComponentBase(ABC):
setattr(self._param, self._param.output_var_name, None)
self._param.inputs = []
def set_output(self, v: partial | pd.DataFrame):
def set_output(self, v):
setattr(self._param, self._param.output_var_name, v)
def get_input(self):
if self._param.debug_inputs:
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs])
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
@ -531,14 +536,15 @@ class ComponentBase(ABC):
eles = []
for q in self._param.query:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
cpn_id = q["component_id"]
if cpn_id.split("@")[0].lower().find("begin") >= 0:
cpn_id, key = cpn_id.split("@")
eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
continue
eles.append({"key": q["key"], "component_id": q["component_id"]})
eles.append({"name": self._canvas.get_compnent_name(cpn_id), "key": cpn_id})
else:
eles.append({"key": q["key"]})
eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
return eles
def get_stream_input(self):
@ -558,3 +564,6 @@ class ComponentBase(ABC):
def get_component_name(self, cpn_id):
return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
def debug(self, **kwargs):
return self._run([], **kwargs)

View File

@ -43,7 +43,7 @@ class Begin(ComponentBase):
def stream_output(self):
res = {"content": self._param.prologue}
yield res
self.set_output(res)
self.set_output(self.be_output(res))

View File

@ -111,9 +111,9 @@ class Generate(ComponentBase):
def get_input_elements(self):
if self._param.parameters:
return self._param.parameters
return [{"key": "user"}, *self._param.parameters]
return [{"key": "input"}]
return [{"key": "user"}]
def _run(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
@ -218,4 +218,16 @@ class Generate(ComponentBase):
res = self.set_cite(retrieval_res, answer)
yield res
self.set_output(res)
self.set_output(Generate.be_output(res))
def debug(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
prompt = self._param.prompt
for para in self._param.debug_inputs:
kwargs[para["key"]] = para["value"]
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
return chat_mdl.chat(prompt, [{"role": "user", "content": kwargs.get("user", "")}], self._param.gen_conf())

View File

@ -187,10 +187,32 @@ def reset():
@manager.route('/input_elements', methods=['GET']) # noqa: F821
@validate_request("id", "component_id")
@login_required
def input_elements():
cvs_id = request.args.get("id")
cpn_id = request.args.get("component_id")
try:
e, user_canvas = UserCanvasService.get_by_id(cvs_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not UserCanvasService.query(user_id=current_user.id, id=cvs_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
return get_json_result(data=canvas.get_component_input_elements(cpn_id))
except Exception as e:
return server_error_response(e)
@manager.route('/debug', methods=['POST']) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
for p in req["params"]:
assert p.get("key")
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
@ -201,7 +223,9 @@ def input_elements():
code=RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
return get_json_result(data=canvas.get_component_input_elements(req["component_id"]))
canvas.get_component(req["component_id"])["obj"]._param.debug_inputs = req["params"]
df = canvas.get_component(req["component_id"])["obj"].debug()
return get_json_result(data=df.to_dict(orient="records"))
except Exception as e:
return server_error_response(e)

View File

@ -95,6 +95,8 @@ def get():
return d.get(k1, d.get(k2))
for ref in conv.reference:
if isinstance(ref, list):
continue
ref["chunks"] = [{
"id": get_value(ck, "chunk_id", "id"),
"content": get_value(ck, "content", "content_with_weight"),

View File

@ -552,7 +552,7 @@ def parse():
})
driver = Chrome(options=options)
driver.get(url)
res_headers = [r.response.headers for r in driver.requests]
res_headers = [r.response.headers for r in driver.requests if r and r.response]
if len(res_headers) > 1:
sections = RAGFlowHtmlParser().parser_txt(driver.page_source)
driver.quit()

View File

@ -54,7 +54,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
@ -269,7 +269,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
batch_size = 16
tts, cnts = [], []
for d in docs:
tts.append(rmSpace(d.get("docnm_kwd", "Title")))
tts.append(d.get("docnm_kwd", "Title"))
c = "\n".join(d.get("question_kwd", []))
if not c:
c = d["content_with_weight"]