Fix errors detected by Ruff (#3918)
### What problem does this PR solve?

Fix errors detected by Ruff

### Type of change

- [x] Refactoring
parent e267a026f3
commit 0d68a6cd1b
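The hunks below are mostly mechanical rewrites of a few recurring patterns: compound statements collapsed onto one line, `except Exception as e:` clauses whose bound name is never used, and `type(x) != list`-style comparisons. As a hedged illustration only (the function, data, and rule codes here are mine for illustration, not taken from the repository or its Ruff configuration), this is the shape of change the linter pushes toward:

```python
# Hypothetical helper, for illustration only — not part of the RAGFlow diff below.
def normalize_rows(rows):
    """Return each non-empty row as a list of floats, skipping rows that fail to parse."""
    out = []
    for row in rows:
        # Before: `if not row: continue` on one line (e.g. Ruff E701).
        if not row:
            continue
        # Before: `if type(row) != list: row = [row]` (e.g. Ruff E721).
        if not isinstance(row, list):
            row = [row]
        try:
            out.append([float(v) for v in row])
        # Before: `except Exception as e:` with `e` never used (e.g. Ruff F841).
        except Exception:
            continue
    return out


print(normalize_rows([[1, "2"], "3", None, ["x"]]))  # -> [[1.0, 2.0], [3.0]]
```

The hunks that follow apply the same kind of expansion by hand across the agent components, API routes, services, and parsers.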
@@ -133,7 +133,8 @@ class Canvas(ABC):
             "components": {}
         }
         for k in self.dsl.keys():
-            if k in ["components"]:continue
+            if k in ["components"]:
+                continue
             dsl[k] = deepcopy(self.dsl[k])

         for k, cpn in self.components.items():
@@ -158,7 +159,8 @@ class Canvas(ABC):

     def get_compnent_name(self, cid):
         for n in self.dsl["graph"]["nodes"]:
-            if cid == n["id"]: return n["data"]["name"]
+            if cid == n["id"]:
+                return n["data"]["name"]
         return ""

     def run(self, **kwargs):
@@ -173,7 +175,8 @@ class Canvas(ABC):
             if kwargs.get("stream"):
                 for an in ans():
                     yield an
-            else: yield ans
+            else:
+                yield ans
             return

         if not self.path:
@@ -188,7 +191,8 @@ class Canvas(ABC):
         def prepare2run(cpns):
             nonlocal ran, ans
             for c in cpns:
-                if self.path[-1] and c == self.path[-1][-1]: continue
+                if self.path[-1] and c == self.path[-1][-1]:
+                    continue
                 cpn = self.components[c]["obj"]
                 if cpn.component_name == "Answer":
                     self.answer.append(c)
@@ -197,7 +201,8 @@ class Canvas(ABC):
                 if c not in without_dependent_checking:
                     cpids = cpn.get_dependent_components()
                     if any([cc not in self.path[-1] for cc in cpids]):
-                        if c not in waiting: waiting.append(c)
+                        if c not in waiting:
+                            waiting.append(c)
                         continue
                 yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
                 ans = cpn.run(self.history, **kwargs)
@@ -211,10 +216,12 @@ class Canvas(ABC):
             logging.debug(f"Canvas.run: {ran} {self.path}")
             cpn_id = self.path[-1][ran]
             cpn = self.get_component(cpn_id)
-            if not cpn["downstream"]: break
+            if not cpn["downstream"]:
+                break

             loop = self._find_loop()
-            if loop: raise OverflowError(f"Too much loops: {loop}")
+            if loop:
+                raise OverflowError(f"Too much loops: {loop}")

             if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
                 switch_out = cpn["obj"].output()[1].iloc[0, 0]
@@ -283,19 +290,22 @@ class Canvas(ABC):

     def _find_loop(self, max_loops=6):
         path = self.path[-1][::-1]
-        if len(path) < 2: return False
+        if len(path) < 2:
+            return False

         for i in range(len(path)):
             if path[i].lower().find("answer") >= 0:
                 path = path[:i]
                 break

-        if len(path) < 2: return False
+        if len(path) < 2:
+            return False

-        for l in range(2, len(path) // 2):
-            pat = ",".join(path[0:l])
+        for loc in range(2, len(path) // 2):
+            pat = ",".join(path[0:loc])
             path_str = ",".join(path)
-            if len(pat) >= len(path_str): return False
+            if len(pat) >= len(path_str):
+                return False
             loop = max_loops
             while path_str.find(pat) == 0 and loop >= 0:
                 loop -= 1
@@ -303,7 +313,7 @@ class Canvas(ABC):
                     return False
                 path_str = path_str[len(pat)+1:]
             if loop < 0:
-                pat = " => ".join([p.split(":")[0] for p in path[0:l]])
+                pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
                 return pat + " => " + pat

         return False
@@ -39,3 +39,73 @@ def component_class(class_name):
     m = importlib.import_module("agent.component")
     c = getattr(m, class_name)
     return c
+
+__all__ = [
+    "Begin",
+    "BeginParam",
+    "Generate",
+    "GenerateParam",
+    "Retrieval",
+    "RetrievalParam",
+    "Answer",
+    "AnswerParam",
+    "Categorize",
+    "CategorizeParam",
+    "Switch",
+    "SwitchParam",
+    "Relevant",
+    "RelevantParam",
+    "Message",
+    "MessageParam",
+    "RewriteQuestion",
+    "RewriteQuestionParam",
+    "KeywordExtract",
+    "KeywordExtractParam",
+    "Concentrator",
+    "ConcentratorParam",
+    "Baidu",
+    "BaiduParam",
+    "DuckDuckGo",
+    "DuckDuckGoParam",
+    "Wikipedia",
+    "WikipediaParam",
+    "PubMed",
+    "PubMedParam",
+    "ArXiv",
+    "ArXivParam",
+    "Google",
+    "GoogleParam",
+    "Bing",
+    "BingParam",
+    "GoogleScholar",
+    "GoogleScholarParam",
+    "DeepL",
+    "DeepLParam",
+    "GitHub",
+    "GitHubParam",
+    "BaiduFanyi",
+    "BaiduFanyiParam",
+    "QWeather",
+    "QWeatherParam",
+    "ExeSQL",
+    "ExeSQLParam",
+    "YahooFinance",
+    "YahooFinanceParam",
+    "WenCai",
+    "WenCaiParam",
+    "Jin10",
+    "Jin10Param",
+    "TuShare",
+    "TuShareParam",
+    "AkShare",
+    "AkShareParam",
+    "Crawler",
+    "CrawlerParam",
+    "Invoke",
+    "InvokeParam",
+    "Template",
+    "TemplateParam",
+    "Email",
+    "EmailParam",
+    "component_class"
+]
@@ -428,7 +428,8 @@ class ComponentBase(ABC):
     def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
         o = getattr(self._param, self._param.output_var_name)
         if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
-            if not isinstance(o, list): o = [o]
+            if not isinstance(o, list):
+                o = [o]
             o = pd.DataFrame(o)

         if allow_partial or not isinstance(o, partial):
@@ -440,7 +441,8 @@ class ComponentBase(ABC):
             for oo in o():
                 if not isinstance(oo, pd.DataFrame):
                     outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
-                else: outs = oo
+                else:
+                    outs = oo
         return self._param.output_var_name, outs

     def reset(self):
@@ -482,13 +484,15 @@ class ComponentBase(ABC):
                 outs.append(pd.DataFrame([{"content": q["value"]}]))
             if outs:
                 df = pd.concat(outs, ignore_index=True)
-                if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
+                if "content" in df:
+                    df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
                 return df

         upstream_outs = []

         for u in reversed_cpnts[::-1]:
-            if self.get_component_name(u) in ["switch", "concentrator"]: continue
+            if self.get_component_name(u) in ["switch", "concentrator"]:
+                continue
             if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
                 o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                 if o is not None:
@@ -532,7 +536,8 @@ class ComponentBase(ABC):
         reversed_cpnts.extend(self._canvas.path[-1])

         for u in reversed_cpnts[::-1]:
-            if self.get_component_name(u) in ["switch", "answer"]: continue
+            if self.get_component_name(u) in ["switch", "answer"]:
+                continue
             return self._canvas.get_component(u)["obj"].output()[1]

     @staticmethod
@@ -34,15 +34,18 @@ class CategorizeParam(GenerateParam):
         super().check()
         self.check_empty(self.category_description, "[Categorize] Category examples")
         for k, v in self.category_description.items():
-            if not k: raise ValueError("[Categorize] Category name can not be empty!")
-            if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
+            if not k:
+                raise ValueError("[Categorize] Category name can not be empty!")
+            if not v.get("to"):
+                raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")

     def get_prompt(self):
         cate_lines = []
         for c, desc in self.category_description.items():
-            for l in desc.get("examples", "").split("\n"):
-                if not l: continue
-                cate_lines.append("Question: {}\tCategory: {}".format(l, c))
+            for line in desc.get("examples", "").split("\n"):
+                if not line:
+                    continue
+                cate_lines.append("Question: {}\tCategory: {}".format(line, c))
         descriptions = []
         for c, desc in self.category_description.items():
             if desc.get("description"):
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 from abc import ABC
-import re
 from agent.component.base import ComponentBase, ComponentParamBase
 import deepl

@@ -46,8 +46,10 @@ class ExeSQLParam(ComponentParamBase):
         self.check_empty(self.password, "Database password")
         self.check_positive_integer(self.top_n, "Number of records")
         if self.database == "rag_flow":
-            if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
-            if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
+            if self.host == "ragflow-mysql":
+                raise ValueError("The host is not accessible.")
+            if self.password == "infini_rag_flow":
+                raise ValueError("The host is not accessible.")


 class ExeSQL(ComponentBase, ABC):
@@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase):

     def gen_conf(self):
         conf = {}
-        if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
-        if self.temperature > 0: conf["temperature"] = self.temperature
-        if self.top_p > 0: conf["top_p"] = self.top_p
-        if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
-        if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
+        if self.max_tokens > 0:
+            conf["max_tokens"] = self.max_tokens
+        if self.temperature > 0:
+            conf["temperature"] = self.temperature
+        if self.top_p > 0:
+            conf["top_p"] = self.top_p
+        if self.presence_penalty > 0:
+            conf["presence_penalty"] = self.presence_penalty
+        if self.frequency_penalty > 0:
+            conf["frequency_penalty"] = self.frequency_penalty
         return conf

@@ -83,7 +88,8 @@ class Generate(ComponentBase):
             recall_docs = []
             for i in idx:
                 did = retrieval_res.loc[int(i), "doc_id"]
-                if did in doc_ids: continue
+                if did in doc_ids:
+                    continue
                 doc_ids.add(did)
                 recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})

@@ -108,7 +114,8 @@ class Generate(ComponentBase):
         retrieval_res = []
         self._param.inputs = []
         for para in self._param.parameters:
-            if not para.get("component_id"): continue
+            if not para.get("component_id"):
+                continue
             component_id = para["component_id"].split("@")[0]
             if para["component_id"].lower().find("@") >= 0:
                 cpn_id, key = para["component_id"].split("@")
@ -142,7 +149,8 @@ class Generate(ComponentBase):
|
||||
|
||||
if retrieval_res:
|
||||
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
||||
else: retrieval_res = pd.DataFrame([])
|
||||
else:
|
||||
retrieval_res = pd.DataFrame([])
|
||||
|
||||
for n, v in kwargs.items():
|
||||
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
|
||||
@ -164,9 +172,11 @@ class Generate(ComponentBase):
|
||||
return pd.DataFrame([res])
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if len(msg) < 1: msg.append({"role": "user", "content": ""})
|
||||
if len(msg) < 1:
|
||||
msg.append({"role": "user", "content": ""})
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
||||
if len(msg) < 2: msg.append({"role": "user", "content": ""})
|
||||
if len(msg) < 2:
|
||||
msg.append({"role": "user", "content": ""})
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
|
||||
|
||||
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
|
||||
@ -185,9 +195,11 @@ class Generate(ComponentBase):
|
||||
return
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if len(msg) < 1: msg.append({"role": "user", "content": ""})
|
||||
if len(msg) < 1:
|
||||
msg.append({"role": "user", "content": ""})
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
||||
if len(msg) < 2: msg.append({"role": "user", "content": ""})
|
||||
if len(msg) < 2:
|
||||
msg.append({"role": "user", "content": ""})
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
|
||||
res = {"content": ans, "reference": []}
|
||||
|
@ -95,7 +95,8 @@ class RewriteQuestion(Generate, ABC):
|
||||
hist = self._canvas.get_history(4)
|
||||
conv = []
|
||||
for m in hist:
|
||||
if m["role"] not in ["user", "assistant"]: continue
|
||||
if m["role"] not in ["user", "assistant"]:
|
||||
continue
|
||||
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
|
||||
conv = "\n".join(conv)
|
||||
|
||||
|
@ -41,7 +41,8 @@ class SwitchParam(ComponentParamBase):
|
||||
def check(self):
|
||||
self.check_empty(self.conditions, "[Switch] conditions")
|
||||
for cond in self.conditions:
|
||||
if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
|
||||
if not cond["to"]:
|
||||
raise ValueError("[Switch] 'To' can not be empty!")
|
||||
|
||||
|
||||
class Switch(ComponentBase, ABC):
|
||||
@ -51,7 +52,8 @@ class Switch(ComponentBase, ABC):
|
||||
res = []
|
||||
for cond in self._param.conditions:
|
||||
for item in cond["items"]:
|
||||
if not item["cpn_id"]: continue
|
||||
if not item["cpn_id"]:
|
||||
continue
|
||||
if item["cpn_id"].find("begin") >= 0:
|
||||
continue
|
||||
cid = item["cpn_id"].split("@")[0]
|
||||
@ -63,7 +65,8 @@ class Switch(ComponentBase, ABC):
|
||||
for cond in self._param.conditions:
|
||||
res = []
|
||||
for item in cond["items"]:
|
||||
if not item["cpn_id"]:continue
|
||||
if not item["cpn_id"]:
|
||||
continue
|
||||
cid = item["cpn_id"].split("@")[0]
|
||||
if item["cpn_id"].find("@") > 0:
|
||||
cpn_id, key = item["cpn_id"].split("@")
|
||||
@@ -107,22 +110,22 @@ class Switch(ComponentBase, ABC):
         elif operator == ">":
             try:
                 return True if float(input) > float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input > value else False
         elif operator == "<":
             try:
                 return True if float(input) < float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input < value else False
         elif operator == "≥":
             try:
                 return True if float(input) >= float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input >= value else False
         elif operator == "≤":
             try:
                 return True if float(input) <= float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input <= value else False

         raise ValueError('Not supported operator' + operator)
@ -47,7 +47,8 @@ class Template(ComponentBase):
|
||||
|
||||
self._param.inputs = []
|
||||
for para in self._param.parameters:
|
||||
if not para.get("component_id"): continue
|
||||
if not para.get("component_id"):
|
||||
continue
|
||||
component_id = para["component_id"].split("@")[0]
|
||||
if para["component_id"].lower().find("@") >= 0:
|
||||
cpn_id, key = para["component_id"].split("@")
|
||||
|
@ -43,6 +43,7 @@ if __name__ == '__main__':
|
||||
else:
|
||||
print(ans["content"])
|
||||
|
||||
if DEBUG: print(canvas.path)
|
||||
if DEBUG:
|
||||
print(canvas.path)
|
||||
question = input("\n==================== User =====================\n> ")
|
||||
canvas.add_user_input(question)
|
||||
|
@ -142,7 +142,6 @@ def set_conversation():
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
try:
|
||||
if objs[0].source == "agent":
|
||||
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
|
||||
@ -188,7 +187,8 @@ def completion():
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req: req["quote"] = False
|
||||
if "quote" not in req:
|
||||
req["quote"] = False
|
||||
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
@ -197,7 +197,8 @@ def completion():
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
@ -674,11 +675,13 @@ def completion_faq():
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req: req["quote"] = True
|
||||
if "quote" not in req:
|
||||
req["quote"] = True
|
||||
|
||||
msg = []
|
||||
msg.append({"role": "user", "content": req["word"]})
|
||||
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
|
@ -13,10 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import json
|
||||
import traceback
|
||||
from functools import partial
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
||||
@ -60,7 +58,8 @@ def rm():
|
||||
def save():
|
||||
req = request.json
|
||||
req["user_id"] = current_user.id
|
||||
if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
if "id" not in req:
|
||||
@ -153,7 +152,8 @@ def run():
|
||||
return resp
|
||||
|
||||
for answer in canvas.run(stream=False):
|
||||
if answer.get("running_status"): continue
|
||||
if answer.get("running_status"):
|
||||
continue
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
|
@ -237,7 +237,8 @@ def create():
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank: d["pagerank_fea"] = kb.pagerank
|
||||
if kb.pagerank:
|
||||
d["pagerank_fea"] = kb.pagerank
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
@ -281,10 +281,12 @@ def thumbup():
|
||||
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
|
||||
if up_down:
|
||||
msg["thumbup"] = True
|
||||
if "feedback" in msg: del msg["feedback"]
|
||||
if "feedback" in msg:
|
||||
del msg["feedback"]
|
||||
else:
|
||||
msg["thumbup"] = False
|
||||
if feedback: msg["feedback"] = feedback
|
||||
if feedback:
|
||||
msg["feedback"] = feedback
|
||||
break
|
||||
|
||||
ConversationService.update_by_id(conv["id"], conv)
|
||||
|
@ -37,10 +37,12 @@ def set_dialog():
|
||||
top_n = req.get("top_n", 6)
|
||||
top_k = req.get("top_k", 1024)
|
||||
rerank_id = req.get("rerank_id", "")
|
||||
if not rerank_id: req["rerank_id"] = ""
|
||||
if not rerank_id:
|
||||
req["rerank_id"] = ""
|
||||
similarity_threshold = req.get("similarity_threshold", 0.1)
|
||||
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
||||
if vector_similarity_weight is None: vector_similarity_weight = 0.3
|
||||
if vector_similarity_weight is None:
|
||||
vector_similarity_weight = 0.3
|
||||
llm_setting = req.get("llm_setting", {})
|
||||
default_prompt = {
|
||||
"system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
|
||||
|
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
@ -90,7 +89,8 @@ def web_crawl():
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
|
||||
blob = html2pdf(url)
|
||||
if not blob: return server_error_response(ValueError("Download failure."))
|
||||
if not blob:
|
||||
return server_error_response(ValueError("Download failure."))
|
||||
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@ -290,7 +290,8 @@ def change_status():
|
||||
def rm():
|
||||
req = request.json
|
||||
doc_ids = req["doc_id"]
|
||||
if isinstance(doc_ids, str): doc_ids = [doc_ids]
|
||||
if isinstance(doc_ids, str):
|
||||
doc_ids = [doc_ids]
|
||||
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
|
@ -351,8 +351,10 @@ def list_app():
|
||||
|
||||
llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
|
||||
for o in objs:
|
||||
if not o.api_key: continue
|
||||
if o.llm_name + "@" + o.llm_factory in llm_set: continue
|
||||
if not o.api_key:
|
||||
continue
|
||||
if o.llm_name + "@" + o.llm_factory in llm_set:
|
||||
continue
|
||||
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
|
||||
|
||||
res = {}
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.utils.api_utils import get_error_data_result, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from flask import request
|
||||
|
@ -41,7 +41,6 @@ from api.utils.api_utils import construct_json_result, get_parser_config
|
||||
from rag.nlp import search
|
||||
from rag.utils import rmSpace
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
import os
|
||||
|
||||
MAXIMUM_OF_UPLOADING_FILES = 256
|
||||
|
||||
@ -976,12 +975,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
if not req.get("content"):
|
||||
return get_error_data_result(message="`content` is required")
|
||||
if "important_keywords" in req:
|
||||
if type(req["important_keywords"]) != list:
|
||||
if not isinstance(req["important_keywords"], list):
|
||||
return get_error_data_result(
|
||||
"`important_keywords` is required to be a list"
|
||||
)
|
||||
if "questions" in req:
|
||||
if type(req["questions"]) != list:
|
||||
if not isinstance(req["questions"], list):
|
||||
return get_error_data_result(
|
||||
"`questions` is required to be a list"
|
||||
)
|
||||
|
@ -143,8 +143,10 @@ def completion(tenant_id, chat_id):
|
||||
}
|
||||
conv.message.append(question)
|
||||
for m in conv.message:
|
||||
if m["role"] == "system": continue
|
||||
if m["role"] == "assistant" and not msg: continue
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
message_id = msg[-1].get("id")
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
@ -267,7 +269,8 @@ def agent_completion(tenant_id, agent_id):
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
stream = req.get("stream", True)
|
||||
@ -361,7 +364,8 @@ def agent_completion(tenant_id, agent_id):
|
||||
return resp
|
||||
|
||||
for answer in canvas.run(stream=False):
|
||||
if answer.get("running_status"): continue
|
||||
if answer.get("running_status"):
|
||||
continue
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
|
@ -330,7 +330,7 @@ def user_info_from_github(access_token):
|
||||
headers=headers,
|
||||
).json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email["primary"] == True), None
|
||||
(email for email in email_info if email["primary"]), None
|
||||
)["email"]
|
||||
return user_info
|
||||
|
||||
|
@ -130,7 +130,7 @@ def is_continuous_field(cls: typing.Type) -> bool:
|
||||
for p in cls.__bases__:
|
||||
if p in CONTINUOUS_FIELD_TYPE:
|
||||
return True
|
||||
elif p != Field and p != object:
|
||||
elif p is not Field and p is not object:
|
||||
if is_continuous_field(p):
|
||||
return True
|
||||
else:
|
||||
|
@ -170,7 +170,7 @@ def add_graph_templates():
|
||||
cnvs = json.load(open(os.path.join(dir, fnm), "r"))
|
||||
try:
|
||||
CanvasTemplateService.save(**cnvs)
|
||||
except:
|
||||
except Exception:
|
||||
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
|
||||
except Exception:
|
||||
logging.exception("Add graph templates error: ")
|
||||
|
@ -15,13 +15,14 @@
|
||||
#
|
||||
import pathlib
|
||||
import re
|
||||
from .user_service import UserService
|
||||
from .user_service import UserService as UserService
|
||||
|
||||
|
||||
def duplicate_name(query_func, **kwargs):
|
||||
fnm = kwargs["name"]
|
||||
objs = query_func(**kwargs)
|
||||
if not objs: return fnm
|
||||
if not objs:
|
||||
return fnm
|
||||
ext = pathlib.Path(fnm).suffix #.jpg
|
||||
nm = re.sub(r"%s$"%ext, "", fnm)
|
||||
r = re.search(r"\(([0-9]+)\)$", nm)
|
||||
@ -31,8 +32,8 @@ def duplicate_name(query_func, **kwargs):
|
||||
nm = re.sub(r"\([0-9]+\)$", "", nm)
|
||||
c += 1
|
||||
nm = f"{nm}({c})"
|
||||
if ext: nm += f"{ext}"
|
||||
if ext:
|
||||
nm += f"{ext}"
|
||||
|
||||
kwargs["name"] = nm
|
||||
return duplicate_name(query_func, **kwargs)
|
||||
|
||||
|
@ -64,7 +64,8 @@ class API4ConversationService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def stats(cls, tenant_id, from_date, to_date, source=None):
|
||||
if len(to_date) == 10: to_date += " 23:59:59"
|
||||
if len(to_date) == 10:
|
||||
to_date += " 23:59:59"
|
||||
return cls.model.select(
|
||||
cls.model.create_date.truncate("day").alias("dt"),
|
||||
peewee.fn.COUNT(
|
||||
|
@ -13,9 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
import peewee
|
||||
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
|
||||
from api.db.db_models import DB, CanvasTemplate, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
|
||||
|
||||
|
@ -115,7 +115,7 @@ class CommonService:
|
||||
try:
|
||||
obj = cls.model.query(id=pid)[0]
|
||||
return True, obj
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
|
@@ -106,15 +106,15 @@ def message_fit_in(msg, max_length=4000):
         return c, msg

     ll = num_tokens_from_string(msg_[0]["content"])
-    l = num_tokens_from_string(msg_[-1]["content"])
-    if ll / (ll + l) > 0.8:
+    ll2 = num_tokens_from_string(msg_[-1]["content"])
+    if ll / (ll + ll2) > 0.8:
         m = msg_[0]["content"]
-        m = encoder.decode(encoder.encode(m)[:max_length - l])
+        m = encoder.decode(encoder.encode(m)[:max_length - ll2])
         msg[0]["content"] = m
         return max_length, msg

     m = msg_[1]["content"]
-    m = encoder.decode(encoder.encode(m)[:max_length - l])
+    m = encoder.decode(encoder.encode(m)[:max_length - ll2])
     msg[1]["content"] = m
     return max_length, msg

@ -257,7 +257,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
|
||||
refs = deepcopy(kbinfos)
|
||||
@ -433,13 +434,15 @@ def relevant(tenant_id, llm_id, question, contents: list):
|
||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
||||
No other words needed except 'yes' or 'no'.
|
||||
"""
|
||||
if not contents:return False
|
||||
if not contents:
|
||||
return False
|
||||
contents = "Documents: \n" + " - ".join(contents)
|
||||
contents = f"Question: {question}\n" + contents
|
||||
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
|
||||
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
|
||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
|
||||
if ans.lower().find("yes") >= 0: return True
|
||||
if ans.lower().find("yes") >= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -481,8 +484,10 @@ Requirements:
|
||||
]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple): kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >=0: return ""
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >=0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
|
||||
@ -508,8 +513,10 @@ Requirements:
|
||||
]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple): kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >= 0: return ""
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
|
||||
@ -520,7 +527,8 @@ def full_question(tenant_id, llm_id, messages):
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
conv = []
|
||||
for m in messages:
|
||||
if m["role"] not in ["user", "assistant"]: continue
|
||||
if m["role"] not in ["user", "assistant"]:
|
||||
continue
|
||||
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
|
||||
conv = "\n".join(conv)
|
||||
today = datetime.date.today().isoformat()
|
||||
@ -581,7 +589,8 @@ Output: What's the weather in Rochester on {tomorrow}?
|
||||
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text: return
|
||||
if not tts_mdl or not text:
|
||||
return
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
@ -641,7 +650,8 @@ def ask(question, kb_ids, tenant_id):
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
refs = deepcopy(kbinfos)
|
||||
for c in refs["chunks"]:
|
||||
|
@ -532,7 +532,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
try:
|
||||
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
|
||||
ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
|
@ -20,7 +20,7 @@ from api.db.db_models import DB
|
||||
from api.db.db_models import File, File2Document
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
|
||||
class File2DocumentService(CommonService):
|
||||
@ -63,7 +63,7 @@ class File2DocumentService(CommonService):
|
||||
def update_by_file_id(cls, file_id, obj):
|
||||
obj["update_time"] = current_timestamp()
|
||||
obj["update_date"] = datetime_format(datetime.now())
|
||||
num = cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||
e, obj = cls.get_by_id(cls.model.id)
|
||||
return obj
|
||||
|
||||
|
@ -85,7 +85,8 @@ class FileService(CommonService):
|
||||
.join(Document, on=(File2Document.document_id == Document.id))
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
||||
.where(cls.model.id == file_id))
|
||||
if not kbs: return []
|
||||
if not kbs:
|
||||
return []
|
||||
kbs_info_list = []
|
||||
for kb in list(kbs.dicts()):
|
||||
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
|
||||
@ -304,7 +305,8 @@ class FileService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
|
||||
for _ in File2DocumentService.get_by_document_id(doc["id"]): return
|
||||
for _ in File2DocumentService.get_by_document_id(doc["id"]):
|
||||
return
|
||||
file = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": kb_folder_id,
|
||||
|
@ -107,7 +107,8 @@ class TenantLLMService(CommonService):
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
if model_config: model_config = model_config.to_dict()
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
if not model_config:
|
||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
|
@ -57,28 +57,33 @@ class TaskService(CommonService):
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
|
||||
cls.model.update_time,
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == task_id)
|
||||
)
|
||||
docs = list(docs.dicts())
|
||||
if not docs: return None
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
msg = "\nTask has been received."
|
||||
prog = random.random() / 10.
|
||||
prog = random.random() / 10.0
|
||||
if docs[0]["retry_count"] >= 3:
|
||||
msg = "\nERROR: Task is abandoned after 3 times attempts."
|
||||
prog = -1
|
||||
|
||||
cls.model.update(progress_msg=cls.model.progress_msg + msg,
|
||||
progress=prog,
|
||||
retry_count=docs[0]["retry_count"]+1
|
||||
).where(
|
||||
cls.model.id == docs[0]["id"]).execute()
|
||||
cls.model.update(
|
||||
progress_msg=cls.model.progress_msg + msg,
|
||||
progress=prog,
|
||||
retry_count=docs[0]["retry_count"] + 1,
|
||||
).where(cls.model.id == docs[0]["id"]).execute()
|
||||
|
||||
if docs[0]["retry_count"] >= 3: return None
|
||||
if docs[0]["retry_count"] >= 3:
|
||||
return None
|
||||
|
||||
return docs[0]
|
||||
|
||||
@ -86,21 +91,44 @@ class TaskService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_ongoing_doc_name(cls):
|
||||
with DB.lock("get_task", -1):
|
||||
docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
|
||||
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
||||
.join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
|
||||
.join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
|
||||
docs = (
|
||||
cls.model.select(
|
||||
*[Document.id, Document.kb_id, Document.location, File.parent_id]
|
||||
)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(
|
||||
File2Document,
|
||||
on=(File2Document.document_id == Document.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.join(
|
||||
File,
|
||||
on=(File2Document.file_id == File.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.where(
|
||||
Document.status == StatusEnum.VALID.value,
|
||||
Document.run == TaskStatus.RUNNING.value,
|
||||
~(Document.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress < 1,
|
||||
cls.model.create_time >= current_timestamp() - 1000 * 600
|
||||
cls.model.create_time >= current_timestamp() - 1000 * 600,
|
||||
)
|
||||
)
|
||||
docs = list(docs.dicts())
|
||||
if not docs: return []
|
||||
if not docs:
|
||||
return []
|
||||
|
||||
return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
|
||||
return list(
|
||||
set(
|
||||
[
|
||||
(
|
||||
d["parent_id"] if d["parent_id"] else d["kb_id"],
|
||||
d["location"],
|
||||
)
|
||||
for d in docs
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -118,28 +146,30 @@ class TaskService(CommonService):
|
||||
def update_progress(cls, id, info):
|
||||
if os.environ.get("MACOS"):
|
||||
if info["progress_msg"]:
|
||||
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.update(
|
||||
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
|
||||
).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
cls.model.update(progress=info["progress"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.id == id
|
||||
).execute()
|
||||
return
|
||||
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.update(
|
||||
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
|
||||
).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
cls.model.update(progress=info["progress"]).where(
|
||||
cls.model.id == id).execute()
|
||||
cls.model.id == id
|
||||
).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
def new_task():
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"]
|
||||
}
|
||||
return {"id": get_uuid(), "doc_id": doc["id"]}
|
||||
|
||||
tsks = []
|
||||
|
||||
if doc["type"] == FileType.PDF.value:
|
||||
@ -150,8 +180,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
if doc["parser_id"] == "paper":
|
||||
page_size = doc["parser_config"].get("task_page_size", 22)
|
||||
if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
|
||||
page_size = 10 ** 9
|
||||
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
|
||||
page_size = 10**9
|
||||
page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
|
||||
for s, e in page_ranges:
|
||||
s -= 1
|
||||
s = max(0, s)
|
||||
@ -177,4 +207,6 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
|
||||
for t in tsks:
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
|
||||
assert REDIS_CONN.queue_product(
|
||||
SVR_QUEUE_NAME, message=t
|
||||
), "Can't access Redis. Please check the Redis' status."
|
||||
|
@ -22,7 +22,7 @@ from api.db import UserTenantRole
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import User, Tenant
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
|
||||
from api.utils import get_uuid, current_timestamp, datetime_format
|
||||
from api.db import StatusEnum
|
||||
|
||||
|
||||
|
@ -21,10 +21,7 @@
|
||||
import logging
|
||||
import os
|
||||
from api.utils.log_utils import initRootLogger
|
||||
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
|
||||
initRootLogger("ragflow_server", LOG_LEVELS)
|
||||
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
@ -44,6 +41,9 @@ from api.versions import get_ragflow_version
|
||||
from api.utils import show_configs
|
||||
from rag.settings import print_rag_settings
|
||||
|
||||
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
|
||||
initRootLogger("ragflow_server", LOG_LEVELS)
|
||||
|
||||
|
||||
def update_progress():
|
||||
while True:
|
||||
|
@ -36,7 +36,6 @@ from werkzeug.http import HTTP_STATUS_CODES
|
||||
from api.db.db_models import APIToken
|
||||
from api import settings
|
||||
|
||||
from api import settings
|
||||
from api.utils import CustomJSONEncoder, get_uuid
|
||||
from api.utils import json_dumps
|
||||
from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC
|
||||
|
@ -45,5 +45,5 @@ try:
|
||||
pool = Pool(processes=1)
|
||||
thread = pool.apply_async(download_nltk_data)
|
||||
binary = thread.get(timeout=60)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
print('\x1b[6;37;41m WARNING \x1b[0m' + "Downloading NLTK data failure.", flush=True)
|
||||
|
@ -18,4 +18,16 @@ from .ppt_parser import RAGFlowPptParser as PptParser
|
||||
from .html_parser import RAGFlowHtmlParser as HtmlParser
|
||||
from .json_parser import RAGFlowJsonParser as JsonParser
|
||||
from .markdown_parser import RAGFlowMarkdownParser as MarkdownParser
|
||||
from .txt_parser import RAGFlowTxtParser as TxtParser
|
||||
from .txt_parser import RAGFlowTxtParser as TxtParser
|
||||
|
||||
__all__ = [
|
||||
"PdfParser",
|
||||
"PlainParser",
|
||||
"DocxParser",
|
||||
"ExcelParser",
|
||||
"PptParser",
|
||||
"HtmlParser",
|
||||
"JsonParser",
|
||||
"MarkdownParser",
|
||||
"TxtParser",
|
||||
]
|
@ -29,7 +29,8 @@ class RAGFlowExcelParser:
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
rows = list(ws.rows)
|
||||
if not rows: continue
|
||||
if not rows:
|
||||
continue
|
||||
|
||||
tb_rows_0 = "<tr>"
|
||||
for t in list(rows[0]):
|
||||
@ -40,7 +41,9 @@ class RAGFlowExcelParser:
|
||||
tb = ""
|
||||
tb += f"<table><caption>{sheetname}</caption>"
|
||||
tb += tb_rows_0
|
||||
for r in list(rows[1 + chunk_i * chunk_rows:1 + (chunk_i + 1) * chunk_rows]):
|
||||
for r in list(
|
||||
rows[1 + chunk_i * chunk_rows : 1 + (chunk_i + 1) * chunk_rows]
|
||||
):
|
||||
tb += "<tr>"
|
||||
for i, c in enumerate(r):
|
||||
if c.value is None:
|
||||
@ -62,20 +65,21 @@ class RAGFlowExcelParser:
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
rows = list(ws.rows)
|
||||
if not rows:continue
|
||||
if not rows:
|
||||
continue
|
||||
ti = list(rows[0])
|
||||
for r in list(rows[1:]):
|
||||
l = []
|
||||
fields = []
|
||||
for i, c in enumerate(r):
|
||||
if not c.value:
|
||||
continue
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
l.append(t)
|
||||
l = "; ".join(l)
|
||||
fields.append(t)
|
||||
line = "; ".join(fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
l += " ——" + sheetname
|
||||
res.append(l)
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
|
@ -36,7 +36,7 @@ class RAGFlowHtmlParser:
|
||||
|
||||
@classmethod
|
||||
def parser_txt(cls, txt):
|
||||
if type(txt) != str:
|
||||
if not isinstance(txt, str):
|
||||
raise TypeError("txt type should be str!")
|
||||
html_doc = readability.Document(txt)
|
||||
title = html_doc.title()
|
||||
|
@ -22,7 +22,7 @@ class RAGFlowJsonParser:
|
||||
txt = binary.decode(encoding, errors="ignore")
|
||||
json_data = json.loads(txt)
|
||||
chunks = self.split_json(json_data, True)
|
||||
sections = [json.dumps(l, ensure_ascii=False) for l in chunks if l]
|
||||
sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
|
||||
return sections
|
||||
|
||||
@staticmethod
|
||||
|
@ -752,7 +752,7 @@ class RAGFlowPdfParser:
|
||||
"x1": np.max([b["x1"] for b in bxs]),
|
||||
"bottom": np.max([b["bottom"] for b in bxs]) - ht
|
||||
}
|
||||
louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
|
||||
louts = [layout for layout in self.page_layout[pn] if layout["type"] == ltype]
|
||||
ii = Recognizer.find_overlapped(b, louts, naive=True)
|
||||
if ii is not None:
|
||||
b = louts[ii]
|
||||
@ -763,7 +763,8 @@ class RAGFlowPdfParser:
|
||||
"layoutno", "")))
|
||||
|
||||
left, top, right, bott = b["x0"], b["top"], b["x1"], b["bottom"]
|
||||
if right < left: right = left + 1
|
||||
if right < left:
|
||||
right = left + 1
|
||||
poss.append((pn + self.page_from, left, right, top, bott))
|
||||
return self.page_images[pn] \
|
||||
.crop((left * ZM, top * ZM,
|
||||
@ -845,7 +846,8 @@ class RAGFlowPdfParser:
|
||||
top = bx["top"] - self.page_cum_height[pn[0] - 1]
|
||||
bott = bx["bottom"] - self.page_cum_height[pn[0] - 1]
|
||||
page_images_cnt = len(self.page_images)
|
||||
if pn[-1] - 1 >= page_images_cnt: return ""
|
||||
if pn[-1] - 1 >= page_images_cnt:
|
||||
return ""
|
||||
while bott * ZM > self.page_images[pn[-1] - 1].size[1]:
|
||||
bott -= self.page_images[pn[-1] - 1].size[1] / ZM
|
||||
pn.append(pn[-1] + 1)
|
||||
@ -889,7 +891,6 @@ class RAGFlowPdfParser:
|
||||
nonlocal mh, pw, lines, widths
|
||||
lines.append(line)
|
||||
widths.append(width(line))
|
||||
width_mean = np.mean(widths)
|
||||
mmj = self.proj_match(
|
||||
line["text"]) or line.get(
|
||||
"layout_type",
|
||||
@ -994,7 +995,7 @@ class RAGFlowPdfParser:
|
||||
else:
|
||||
self.is_english = False
|
||||
|
||||
st = timer()
|
||||
# st = timer()
|
||||
for i, img in enumerate(self.page_images_x2):
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(
|
||||
@ -1028,8 +1029,8 @@ class RAGFlowPdfParser:
|
||||
|
||||
self.page_cum_height = np.cumsum(self.page_cum_height)
|
||||
assert len(self.page_cum_height) == len(self.page_images) + 1
|
||||
if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
|
||||
page_to, callback)
|
||||
if len(self.boxes) == 0 and zoomin < 9:
|
||||
self.__images__(fnm, zoomin * 3, page_from, page_to, callback)
|
||||
|
||||
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
||||
self.__images__(fnm, zoomin)
|
||||
@ -1168,7 +1169,7 @@ class PlainParser(object):
|
||||
if not self.outlines:
|
||||
logging.warning("Miss outlines")
|
||||
|
||||
return [(l, "") for l in lines], []
|
||||
return [(line, "") for line in lines], []
|
||||
|
||||
def crop(self, ck, need_position):
|
||||
raise NotImplementedError
|
||||
|
@ -15,21 +15,42 @@ import datetime
|
||||
|
||||
|
||||
def refactor(cv):
|
||||
for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]:
|
||||
if n in cv and cv[n] is not None: del cv[n]
|
||||
for n in [
|
||||
"raw_txt",
|
||||
"parser_name",
|
||||
"inference",
|
||||
"ori_text",
|
||||
"use_time",
|
||||
"time_stat",
|
||||
]:
|
||||
if n in cv and cv[n] is not None:
|
||||
del cv[n]
|
||||
cv["is_deleted"] = 0
|
||||
if "basic" not in cv: cv["basic"] = {}
|
||||
if cv["basic"].get("photo2"): del cv["basic"]["photo2"]
|
||||
if "basic" not in cv:
|
||||
cv["basic"] = {}
|
||||
if cv["basic"].get("photo2"):
|
||||
del cv["basic"]["photo2"]
|
||||
|
||||
for n in ["education", "work", "certificate", "project", "language", "skill", "training"]:
|
||||
if n not in cv or cv[n] is None: continue
|
||||
if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()]
|
||||
if type(cv[n]) != type([]):
|
||||
for n in [
|
||||
"education",
|
||||
"work",
|
||||
"certificate",
|
||||
"project",
|
||||
"language",
|
||||
"skill",
|
||||
"training",
|
||||
]:
|
||||
if n not in cv or cv[n] is None:
|
||||
continue
|
||||
if isinstance(cv[n], dict):
|
||||
cv[n] = [v for _, v in cv[n].items()]
|
||||
if not isinstance(cv[n], list):
|
||||
del cv[n]
|
||||
continue
|
||||
vv = []
|
||||
for v in cv[n]:
|
||||
if "external" in v and v["external"] is not None: del v["external"]
|
||||
if "external" in v and v["external"] is not None:
|
||||
del v["external"]
|
||||
vv.append(v)
|
||||
cv[n] = {str(i): vv[i] for i in range(len(vv))}
|
||||
|
||||
@ -42,24 +63,44 @@ def refactor(cv):
|
||||
cv["basic"][t] = cv["basic"][n]
|
||||
del cv["basic"][n]
|
||||
|
||||
work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", ""))
|
||||
edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", ""))
|
||||
work = sorted(
|
||||
[v for _, v in cv.get("work", {}).items()],
|
||||
key=lambda x: x.get("start_time", ""),
|
||||
)
|
||||
edu = sorted(
|
||||
[v for _, v in cv.get("education", {}).items()],
|
||||
key=lambda x: x.get("start_time", ""),
|
||||
)
|
||||
|
||||
if work:
|
||||
cv["basic"]["work_start_time"] = work[0].get("start_time", "")
|
||||
cv["basic"]["management_experience"] = 'Y' if any(
|
||||
[w.get("management_experience", '') == 'Y' for w in work]) else 'N'
|
||||
cv["basic"]["management_experience"] = (
|
||||
"Y"
|
||||
if any([w.get("management_experience", "") == "Y" for w in work])
|
||||
else "N"
|
||||
)
|
||||
cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")
|
||||
|
||||
for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities",
|
||||
"corporation_type", "scale", "corporation_name"]:
|
||||
for n in [
|
||||
"annual_salary_from",
|
||||
"annual_salary_to",
|
||||
"industry_name",
|
||||
"position_name",
|
||||
"responsibilities",
|
||||
"corporation_type",
|
||||
"scale",
|
||||
"corporation_name",
|
||||
]:
|
||||
cv["basic"][n] = work[-1].get(n, "")
|
||||
|
||||
if edu:
|
||||
for n in ["school_name", "discipline_name"]:
|
||||
if n in edu[-1]: cv["basic"][n] = edu[-1][n]
|
||||
if n in edu[-1]:
|
||||
cv["basic"][n] = edu[-1][n]
|
||||
|
||||
cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if "contact" not in cv: cv["contact"] = {}
|
||||
if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "")
|
||||
return cv
|
||||
if "contact" not in cv:
|
||||
cv["contact"] = {}
|
||||
if not cv["contact"].get("name"):
|
||||
cv["contact"]["name"] = cv["basic"].get("name", "")
|
||||
return cv
|
||||
|
@ -21,13 +21,18 @@ from . import regions
|
||||
|
||||
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0)
|
||||
GOODS = pd.read_csv(
|
||||
os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0
|
||||
).fillna(0)
|
||||
GOODS["cid"] = GOODS["cid"].astype(str)
|
||||
GOODS = GOODS.set_index(["cid"])
|
||||
CORP_TKS = json.load(open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r"))
|
||||
CORP_TKS = json.load(
|
||||
open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")
|
||||
)
|
||||
GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r"))
|
||||
CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r"))
|
||||
|
||||
|
||||
def baike(cid, default_v=0):
|
||||
global GOODS
|
||||
try:
|
||||
@ -39,27 +44,41 @@ def baike(cid, default_v=0):
|
||||
|
||||
def corpNorm(nm, add_region=True):
|
||||
global CORP_TKS
|
||||
if not nm or type(nm)!=type(""):return ""
|
||||
if not nm or isinstance(nm, str):
|
||||
return ""
|
||||
nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower()
|
||||
nm = re.sub(r"&", "&", nm)
|
||||
nm = re.sub(r"[\(\)()\+'\"\t \*\\【】-]+", " ", nm)
|
||||
nm = re.sub(r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE)
|
||||
nm = re.sub(r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", nm, 10000, re.IGNORECASE)
|
||||
if not nm or (len(nm)<5 and not regions.isName(nm[0:2])):return nm
|
||||
nm = re.sub(
|
||||
r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE
|
||||
)
|
||||
nm = re.sub(
|
||||
r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$",
|
||||
"",
|
||||
nm,
|
||||
10000,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if not nm or (len(nm) < 5 and not regions.isName(nm[0:2])):
|
||||
return nm
|
||||
|
||||
tks = rag_tokenizer.tokenize(nm).split()
|
||||
reg = [t for i,t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
|
||||
reg = [t for i, t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
|
||||
nm = ""
|
||||
for t in tks:
|
||||
if regions.isName(t) or t in CORP_TKS:continue
|
||||
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):nm += " "
|
||||
if regions.isName(t) or t in CORP_TKS:
|
||||
continue
|
||||
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):
|
||||
nm += " "
|
||||
nm += t
|
||||
|
||||
r = re.search(r"^([^a-z0-9 \(\)&]{2,})[a-z ]{4,}$", nm.strip())
|
||||
if r:nm = r.group(1)
|
||||
if r:
|
||||
nm = r.group(1)
|
||||
r = re.search(r"^([a-z ]{3,})[^a-z0-9 \(\)&]{2,}$", nm.strip())
|
||||
if r:nm = r.group(1)
|
||||
return nm.strip() + (("" if not reg else "(%s)"%reg[0]) if add_region else "")
|
||||
if r:
|
||||
nm = r.group(1)
|
||||
return nm.strip() + (("" if not reg else "(%s)" % reg[0]) if add_region else "")
|
||||
|
||||
|
||||
def rmNoise(n):
|
||||
@ -67,33 +86,40 @@ def rmNoise(n):
|
||||
n = re.sub(r"[,. &()()]+", "", n)
|
||||
return n
|
||||
|
||||
|
||||
GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP])
|
||||
for c,v in CORP_TAG.items():
|
||||
for c, v in CORP_TAG.items():
|
||||
cc = corpNorm(rmNoise(c), False)
|
||||
if not cc:
|
||||
logging.debug(c)
|
||||
CORP_TAG = {corpNorm(rmNoise(c), False):v for c,v in CORP_TAG.items()}
|
||||
CORP_TAG = {corpNorm(rmNoise(c), False): v for c, v in CORP_TAG.items()}
|
||||
|
||||
|
||||
def is_good(nm):
|
||||
global GOOD_CORP
|
||||
if nm.find("外派")>=0:return False
|
||||
if nm.find("外派") >= 0:
|
||||
return False
|
||||
nm = rmNoise(nm)
|
||||
nm = corpNorm(nm, False)
|
||||
for n in GOOD_CORP:
|
||||
if re.match(r"[0-9a-zA-Z]+$", n):
|
||||
if n == nm: return True
|
||||
elif nm.find(n)>=0:return True
|
||||
if n == nm:
|
||||
return True
|
||||
elif nm.find(n) >= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def corp_tag(nm):
|
||||
global CORP_TAG
|
||||
nm = rmNoise(nm)
|
||||
nm = corpNorm(nm, False)
|
||||
for n in CORP_TAG.keys():
|
||||
if re.match(r"[0-9a-zA-Z., ]+$", n):
|
||||
if n == nm: return CORP_TAG[n]
|
||||
elif nm.find(n)>=0:
|
||||
if len(n)<3 and len(nm)/len(n)>=2:continue
|
||||
if n == nm:
|
||||
return CORP_TAG[n]
|
||||
elif nm.find(n) >= 0:
|
||||
if len(n) < 3 and len(nm) / len(n) >= 2:
|
||||
continue
|
||||
return CORP_TAG[n]
|
||||
return []
|
||||
|
||||
|
@ -11,27 +11,31 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
TBL = {"94":"EMBA",
|
||||
"6":"MBA",
|
||||
"95":"MPA",
|
||||
"92":"专升本",
|
||||
"4":"专科",
|
||||
"90":"中专",
|
||||
"91":"中技",
|
||||
"86":"初中",
|
||||
"3":"博士",
|
||||
"10":"博士后",
|
||||
"1":"本科",
|
||||
"2":"硕士",
|
||||
"87":"职高",
|
||||
"89":"高中"
|
||||
TBL = {
|
||||
"94": "EMBA",
|
||||
"6": "MBA",
|
||||
"95": "MPA",
|
||||
"92": "专升本",
|
||||
"4": "专科",
|
||||
"90": "中专",
|
||||
"91": "中技",
|
||||
"86": "初中",
|
||||
"3": "博士",
|
||||
"10": "博士后",
|
||||
"1": "本科",
|
||||
"2": "硕士",
|
||||
"87": "职高",
|
||||
"89": "高中",
|
||||
}
|
||||
|
||||
TBL_ = {v:k for k,v in TBL.items()}
|
||||
TBL_ = {v: k for k, v in TBL.items()}
|
||||
|
||||
|
||||
def get_name(id):
|
||||
return TBL.get(str(id), "")
|
||||
|
||||
|
||||
def get_id(nm):
|
||||
if not nm:return ""
|
||||
if not nm:
|
||||
return ""
|
||||
return TBL_.get(nm.upper().strip(), "")
|
||||
|
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -16,8 +16,11 @@ import json
import re
import copy
import pandas as pd

current_file_path = os.path.dirname(os.path.abspath(__file__))
TBL = pd.read_csv(os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0).fillna("")
TBL = pd.read_csv(
os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0
).fillna("")
TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip())
GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r"))
GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH])
@ -26,14 +29,15 @@ GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH])
def loadRank(fnm):
global TBL
TBL["rank"] = 1000000
with open(fnm, "r", encoding='utf-8') as f:
with open(fnm, "r", encoding="utf-8") as f:
while True:
l = f.readline()
if not l:break
l = l.strip("\n").split(",")
line = f.readline()
if not line:
break
line = line.strip("\n").split(",")
try:
nm,rk = l[0].strip(),int(l[1])
#assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
nm, rk = line[0].strip(), int(line[1])
# assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
TBL.loc[((TBL.name_cn == nm) | (TBL.name_en == nm)), "rank"] = rk
except Exception:
pass
@ -44,27 +48,35 @@ loadRank(os.path.join(current_file_path, "res/school.rank.csv"))

def split(txt):
tks = []
for t in re.sub(r"[ \t]+", " ",txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r"[a-zA-Z]", t) and tks:
for t in re.sub(r"[ \t]+", " ", txt).split():
if (
tks
and re.match(r".*[a-zA-Z]$", tks[-1])
and re.match(r"[a-zA-Z]", t)
and tks
):
tks[-1] = tks[-1] + " " + t
else:tks.append(t)
else:
tks.append(t)
return tks


def select(nm):
global TBL
if not nm:return
if isinstance(nm, list):nm = str(nm[0])
if not nm:
return
if isinstance(nm, list):
nm = str(nm[0])
nm = split(nm)[0]
nm = str(nm).lower().strip()
nm = re.sub(r"[((][^()()]+[))]", "", nm.lower())
nm = re.sub(r"(^the |[,.&()();;·]+|^(英国|美国|瑞士))", "", nm)
nm = re.sub(r"大学.*学院", "大学", nm)
tbl = copy.deepcopy(TBL)
tbl["hit_alias"] = tbl["alias"].map(lambda x:nm in set(x.split("+")))
res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | (tbl.hit_alias == True))]
if res.empty:return
tbl["hit_alias"] = tbl["alias"].map(lambda x: nm in set(x.split("+")))
res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | tbl.hit_alias)]
if res.empty:
return

return json.loads(res.to_json(orient="records"))[0]

@ -74,4 +86,3 @@ def is_good(nm):
nm = re.sub(r"[((][^()()]+[))]", "", nm.lower())
nm = re.sub(r"[''`‘’“”,. &()();;]+", "", nm)
return nm in GOOD_SCH

@ -25,7 +25,8 @@ from xpinyin import Pinyin
from contextlib import contextmanager


class TimeoutException(Exception): pass
class TimeoutException(Exception):
pass


@contextmanager
@ -50,8 +51,10 @@ def rmHtmlTag(line):


def highest_degree(dg):
if not dg: return ""
if type(dg) == type(""): dg = [dg]
if not dg:
return ""
if isinstance(dg, str):
dg = [dg]
m = {"初中": 0, "高中": 1, "中专": 2, "大专": 3, "专升本": 4, "本科": 5, "硕士": 6, "博士": 7, "博士后": 8}
return sorted([(d, m.get(d, -1)) for d in dg], key=lambda x: x[1] * -1)[0][0]

@ -68,10 +71,12 @@ def forEdu(cv):
for ii, n in enumerate(sorted(cv["education_obj"], key=lambda x: x.get("start_time", "3"))):
e = {}
if n.get("end_time"):
if n["end_time"] > edu_end_dt: edu_end_dt = n["end_time"]
if n["end_time"] > edu_end_dt:
edu_end_dt = n["end_time"]
try:
dt = n["end_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt)
ed_dt.append(str(y))
e["end_dt_kwd"] = str(y)
@ -80,7 +85,8 @@ def forEdu(cv):
if n.get("start_time"):
try:
dt = n["start_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt)
st_dt.append(str(y))
e["start_dt_kwd"] = str(y)
@ -89,13 +95,20 @@ def forEdu(cv):

r = schools.select(n.get("school_name", ""))
if r:
if str(r.get("type", "")) == "1": fea.append("211")
if str(r.get("type", "")) == "2": fea.append("211")
if str(r.get("is_abroad", "")) == "1": fea.append("留学")
if str(r.get("is_double_first", "")) == "1": fea.append("双一流")
if str(r.get("is_985", "")) == "1": fea.append("985")
if str(r.get("is_world_known", "")) == "1": fea.append("海外知名")
if r.get("rank") and cv["school_rank_int"] > r["rank"]: cv["school_rank_int"] = r["rank"]
if str(r.get("type", "")) == "1":
fea.append("211")
if str(r.get("type", "")) == "2":
fea.append("211")
if str(r.get("is_abroad", "")) == "1":
fea.append("留学")
if str(r.get("is_double_first", "")) == "1":
fea.append("双一流")
if str(r.get("is_985", "")) == "1":
fea.append("985")
if str(r.get("is_world_known", "")) == "1":
fea.append("海外知名")
if r.get("rank") and cv["school_rank_int"] > r["rank"]:
cv["school_rank_int"] = r["rank"]

if n.get("school_name") and isinstance(n["school_name"], str):
sch.append(re.sub(r"(211|985|重点大学|[,&;;-])", "", n["school_name"]))
@ -106,22 +119,25 @@ def forEdu(cv):
maj.append(n["discipline_name"])
e["major_kwd"] = n["discipline_name"]

if not n.get("degree") and "985" in fea and not first_fea: n["degree"] = "1"
if not n.get("degree") and "985" in fea and not first_fea:
n["degree"] = "1"

if n.get("degree"):
d = degrees.get_name(n["degree"])
if d: e["degree_kwd"] = d
if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)",
n.get(
"school_name",
""))): d = "专升本"
if d: deg.append(d)
if d:
e["degree_kwd"] = d
if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name",""))):
d = "专升本"
if d:
deg.append(d)

# for first degree
if not fdeg and d in ["中专", "专升本", "专科", "本科", "大专"]:
fdeg = [d]
if n.get("school_name"): fsch = [n["school_name"]]
if n.get("discipline_name"): fmaj = [n["discipline_name"]]
if n.get("school_name"):
fsch = [n["school_name"]]
if n.get("discipline_name"):
fmaj = [n["discipline_name"]]
first_fea = copy.deepcopy(fea)

edu_nst.append(e)
@ -140,16 +156,26 @@ def forEdu(cv):
|
||||
else:
|
||||
cv["sch_rank_kwd"].append("一般学校")
|
||||
|
||||
if edu_nst: cv["edu_nst"] = edu_nst
|
||||
if fea: cv["edu_fea_kwd"] = list(set(fea))
|
||||
if first_fea: cv["edu_first_fea_kwd"] = list(set(first_fea))
|
||||
if maj: cv["major_kwd"] = maj
|
||||
if fsch: cv["first_school_name_kwd"] = fsch
|
||||
if fdeg: cv["first_degree_kwd"] = fdeg
|
||||
if fmaj: cv["first_major_kwd"] = fmaj
|
||||
if st_dt: cv["edu_start_kwd"] = st_dt
|
||||
if ed_dt: cv["edu_end_kwd"] = ed_dt
|
||||
if ed_dt: cv["edu_end_int"] = max([int(t) for t in ed_dt])
|
||||
if edu_nst:
|
||||
cv["edu_nst"] = edu_nst
|
||||
if fea:
|
||||
cv["edu_fea_kwd"] = list(set(fea))
|
||||
if first_fea:
|
||||
cv["edu_first_fea_kwd"] = list(set(first_fea))
|
||||
if maj:
|
||||
cv["major_kwd"] = maj
|
||||
if fsch:
|
||||
cv["first_school_name_kwd"] = fsch
|
||||
if fdeg:
|
||||
cv["first_degree_kwd"] = fdeg
|
||||
if fmaj:
|
||||
cv["first_major_kwd"] = fmaj
|
||||
if st_dt:
|
||||
cv["edu_start_kwd"] = st_dt
|
||||
if ed_dt:
|
||||
cv["edu_end_kwd"] = ed_dt
|
||||
if ed_dt:
|
||||
cv["edu_end_int"] = max([int(t) for t in ed_dt])
|
||||
if deg:
|
||||
if "本科" in deg and "专科" in deg:
|
||||
deg.append("专升本")
|
||||
@ -158,8 +184,10 @@ def forEdu(cv):
|
||||
cv["highest_degree_kwd"] = highest_degree(deg)
|
||||
if edu_end_dt:
|
||||
try:
|
||||
if re.match(r"[0-9]{9,}", edu_end_dt): edu_end_dt = turnTm2Dt(edu_end_dt)
|
||||
if edu_end_dt.strip("\n") == "至今": edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
|
||||
if re.match(r"[0-9]{9,}", edu_end_dt):
|
||||
edu_end_dt = turnTm2Dt(edu_end_dt)
|
||||
if edu_end_dt.strip("\n") == "至今":
|
||||
edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
|
||||
y, m, d = getYMD(edu_end_dt)
|
||||
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
|
||||
except Exception as e:
|
||||
@ -171,7 +199,8 @@ def forEdu(cv):
|
||||
or not cv.get("degree_kwd"):
|
||||
for c in sch:
|
||||
if schools.is_good(c):
|
||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
||||
if "tag_kwd" not in cv:
|
||||
cv["tag_kwd"] = []
|
||||
cv["tag_kwd"].append("好学校")
|
||||
cv["tag_kwd"].append("好学历")
|
||||
break
|
||||
@ -180,28 +209,39 @@ def forEdu(cv):
|
||||
any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \
|
||||
or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \
|
||||
or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]):
|
||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
||||
if "好学历" not in cv["tag_kwd"]: cv["tag_kwd"].append("好学历")
|
||||
if "tag_kwd" not in cv:
|
||||
cv["tag_kwd"] = []
|
||||
if "好学历" not in cv["tag_kwd"]:
|
||||
cv["tag_kwd"].append("好学历")
|
||||
|
||||
if cv.get("major_kwd"): cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
|
||||
if cv.get("school_name_kwd"): cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
|
||||
if cv.get("first_school_name_kwd"): cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
|
||||
if cv.get("first_major_kwd"): cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))
|
||||
if cv.get("major_kwd"):
|
||||
cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
|
||||
if cv.get("school_name_kwd"):
|
||||
cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
|
||||
if cv.get("first_school_name_kwd"):
|
||||
cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
|
||||
if cv.get("first_major_kwd"):
|
||||
cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))
|
||||
|
||||
return cv
|
||||
|
||||
|
||||
def forProj(cv):
|
||||
if not cv.get("project_obj"): return cv
|
||||
if not cv.get("project_obj"):
|
||||
return cv
|
||||
|
||||
pro_nms, desc = [], []
|
||||
for i, n in enumerate(
|
||||
sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if type(x) == type({}) else "",
|
||||
sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "",
|
||||
reverse=True)):
|
||||
if n.get("name"): pro_nms.append(n["name"])
|
||||
if n.get("describe"): desc.append(str(n["describe"]))
|
||||
if n.get("responsibilities"): desc.append(str(n["responsibilities"]))
|
||||
if n.get("achivement"): desc.append(str(n["achivement"]))
|
||||
if n.get("name"):
|
||||
pro_nms.append(n["name"])
|
||||
if n.get("describe"):
|
||||
desc.append(str(n["describe"]))
|
||||
if n.get("responsibilities"):
|
||||
desc.append(str(n["responsibilities"]))
|
||||
if n.get("achivement"):
|
||||
desc.append(str(n["achivement"]))
|
||||
|
||||
if pro_nms:
|
||||
# cv["pro_nms_tks"] = rag_tokenizer.tokenize(" ".join(pro_nms))
|
||||
@ -233,15 +273,16 @@ def forWork(cv):
|
||||
work_st_tm = ""
|
||||
corp_tags = []
|
||||
for i, n in enumerate(
|
||||
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if type(x) == type({}) else "",
|
||||
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "",
|
||||
reverse=True)):
|
||||
if type(n) == type(""):
|
||||
if isinstance(n, str):
|
||||
try:
|
||||
n = json_loads(n)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"]
|
||||
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm):
|
||||
work_st_tm = n["start_time"]
|
||||
for c in flds:
|
||||
if not n.get(c) or str(n[c]) == '0':
|
||||
fea[c].append("")
|
||||
@ -262,14 +303,18 @@ def forWork(cv):
|
||||
fea[c].append(rmHtmlTag(str(n[c]).lower()))
|
||||
|
||||
y, m, d = getYMD(n.get("start_time"))
|
||||
if not y or not m: continue
|
||||
if not y or not m:
|
||||
continue
|
||||
st = "%s-%02d-%02d" % (y, int(m), int(d))
|
||||
latest_job_tm = st
|
||||
|
||||
y, m, d = getYMD(n.get("end_time"))
|
||||
if (not y or not m) and i > 0: continue
|
||||
if not y or not m or int(y) > 2022: y, m, d = getYMD(str(n.get("updated_at", "")))
|
||||
if not y or not m: continue
|
||||
if (not y or not m) and i > 0:
|
||||
continue
|
||||
if not y or not m or int(y) > 2022:
|
||||
y, m, d = getYMD(str(n.get("updated_at", "")))
|
||||
if not y or not m:
|
||||
continue
|
||||
ed = "%s-%02d-%02d" % (y, int(m), int(d))
|
||||
|
||||
try:
|
||||
@ -279,22 +324,28 @@ def forWork(cv):
|
||||
|
||||
if n.get("scale"):
|
||||
r = re.search(r"^([0-9]+)", str(n["scale"]))
|
||||
if r: scales.append(int(r.group(1)))
|
||||
if r:
|
||||
scales.append(int(r.group(1)))
|
||||
|
||||
if goodcorp:
|
||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
||||
if "tag_kwd" not in cv:
|
||||
cv["tag_kwd"] = []
|
||||
cv["tag_kwd"].append("好公司")
|
||||
if goodcorp_:
|
||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
||||
if "tag_kwd" not in cv:
|
||||
cv["tag_kwd"] = []
|
||||
cv["tag_kwd"].append("好公司(曾)")
|
||||
|
||||
if corp_tags:
|
||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
||||
if "tag_kwd" not in cv:
|
||||
cv["tag_kwd"] = []
|
||||
cv["tag_kwd"].extend(corp_tags)
|
||||
cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)]
|
||||
|
||||
if latest_job_tm: cv["latest_job_dt"] = latest_job_tm
|
||||
if fea["corporation_id"]: cv["corporation_id"] = fea["corporation_id"]
|
||||
if latest_job_tm:
|
||||
cv["latest_job_dt"] = latest_job_tm
|
||||
if fea["corporation_id"]:
|
||||
cv["corporation_id"] = fea["corporation_id"]
|
||||
|
||||
if fea["position_name"]:
|
||||
cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0])
|
||||
@ -317,18 +368,23 @@ def forWork(cv):
|
||||
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0])
|
||||
cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:]))
|
||||
|
||||
if fea["subordinates_count"]: fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
|
||||
if fea["subordinates_count"]:
|
||||
fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
|
||||
re.match(r"[^0-9]+$", str(i))]
|
||||
if fea["subordinates_count"]: cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
|
||||
if fea["subordinates_count"]:
|
||||
cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
|
||||
|
||||
if type(cv.get("corporation_id")) == type(1): cv["corporation_id"] = [str(cv["corporation_id"])]
|
||||
if not cv.get("corporation_id"): cv["corporation_id"] = []
|
||||
if isinstance(cv.get("corporation_id"), int):
|
||||
cv["corporation_id"] = [str(cv["corporation_id"])]
|
||||
if not cv.get("corporation_id"):
|
||||
cv["corporation_id"] = []
|
||||
for i in cv.get("corporation_id", []):
|
||||
cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0)
|
||||
|
||||
if work_st_tm:
|
||||
try:
|
||||
if re.match(r"[0-9]{9,}", work_st_tm): work_st_tm = turnTm2Dt(work_st_tm)
|
||||
if re.match(r"[0-9]{9,}", work_st_tm):
|
||||
work_st_tm = turnTm2Dt(work_st_tm)
|
||||
y, m, d = getYMD(work_st_tm)
|
||||
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
|
||||
except Exception as e:
|
||||
@ -339,28 +395,37 @@ def forWork(cv):
|
||||
cv["dua_flt"] = np.mean(duas)
|
||||
cv["cur_dua_int"] = duas[0]
|
||||
cv["job_num_int"] = len(duas)
|
||||
if scales: cv["scale_flt"] = np.max(scales)
|
||||
if scales:
|
||||
cv["scale_flt"] = np.max(scales)
|
||||
return cv
|
||||
|
||||
|
||||
def turnTm2Dt(b):
|
||||
if not b: return
|
||||
if not b:
|
||||
return
|
||||
b = str(b).strip()
|
||||
if re.match(r"[0-9]{10,}", b): b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
|
||||
if re.match(r"[0-9]{10,}", b):
|
||||
b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
|
||||
return b
|
||||
|
||||
|
||||
def getYMD(b):
|
||||
y, m, d = "", "", "01"
|
||||
if not b: return (y, m, d)
|
||||
if not b:
|
||||
return (y, m, d)
|
||||
b = turnTm2Dt(b)
|
||||
if re.match(r"[0-9]{4}", b): y = int(b[:4])
|
||||
if re.match(r"[0-9]{4}", b):
|
||||
y = int(b[:4])
|
||||
r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b)
|
||||
if r: m = r.group(1)
|
||||
if r:
|
||||
m = r.group(1)
|
||||
r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b)
|
||||
if r: d = r.group(1)
|
||||
if not d or int(d) == 0 or int(d) > 31: d = "1"
|
||||
if not m or int(m) > 12 or int(m) < 1: m = "1"
|
||||
if r:
|
||||
d = r.group(1)
|
||||
if not d or int(d) == 0 or int(d) > 31:
|
||||
d = "1"
|
||||
if not m or int(m) > 12 or int(m) < 1:
|
||||
m = "1"
|
||||
return (y, m, d)
|
||||
|
||||
|
||||
@ -369,7 +434,8 @@ def birth(cv):
|
||||
cv["integerity_flt"] *= 0.9
|
||||
return cv
|
||||
y, m, d = getYMD(cv["birth"])
|
||||
if not m or not y: return cv
|
||||
if not m or not y:
|
||||
return cv
|
||||
b = "%s-%02d-%02d" % (y, int(m), int(d))
|
||||
cv["birth_dt"] = b
|
||||
cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d))
|
||||
@ -380,7 +446,8 @@ def birth(cv):
|
||||
|
||||
def parse(cv):
|
||||
for k in cv.keys():
|
||||
if cv[k] == '\\N': cv[k] = ''
|
||||
if cv[k] == '\\N':
|
||||
cv[k] = ''
|
||||
# cv = cv.asDict()
|
||||
tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names",
|
||||
"expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name",
|
||||
@ -402,9 +469,12 @@ def parse(cv):
|
||||
|
||||
rmkeys = []
|
||||
for k in cv.keys():
|
||||
if cv[k] is None: rmkeys.append(k)
|
||||
if (type(cv[k]) == type([]) or type(cv[k]) == type("")) and len(cv[k]) == 0: rmkeys.append(k)
|
||||
for k in rmkeys: del cv[k]
|
||||
if cv[k] is None:
|
||||
rmkeys.append(k)
|
||||
if (isinstance(cv[k], list) or isinstance(cv[k], str)) and len(cv[k]) == 0:
|
||||
rmkeys.append(k)
|
||||
for k in rmkeys:
|
||||
del cv[k]
|
||||
|
||||
integerity = 0.
|
||||
flds_num = 0.
|
||||
@ -414,7 +484,8 @@ def parse(cv):
|
||||
flds_num += len(flds)
|
||||
for f in flds:
|
||||
v = str(cv.get(f, ""))
|
||||
if len(v) > 0 and v != '0' and v != '[]': integerity += 1
|
||||
if len(v) > 0 and v != '0' and v != '[]':
|
||||
integerity += 1
|
||||
|
||||
hasValues(tks_fld)
|
||||
hasValues(small_tks_fld)
|
||||
@ -433,7 +504,8 @@ def parse(cv):
|
||||
(r"[ ()\(\)人/·0-9-]+", ""),
|
||||
(r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]:
|
||||
cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE)
|
||||
if len(cv["corporation_type"]) < 2: del cv["corporation_type"]
|
||||
if len(cv["corporation_type"]) < 2:
|
||||
del cv["corporation_type"]
|
||||
|
||||
if cv.get("political_status"):
|
||||
for p, r in [
|
||||
@ -441,9 +513,11 @@ def parse(cv):
|
||||
(r".*(无党派|公民).*", "群众"),
|
||||
(r".*团员.*", "团员")]:
|
||||
cv["political_status"] = re.sub(p, r, cv["political_status"])
|
||||
if not re.search(r"[党团群]", cv["political_status"]): del cv["political_status"]
|
||||
if not re.search(r"[党团群]", cv["political_status"]):
|
||||
del cv["political_status"]
|
||||
|
||||
if cv.get("phone"): cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
|
||||
if cv.get("phone"):
|
||||
cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
|
||||
|
||||
keys = list(cv.keys())
|
||||
for k in keys:
|
||||
@ -454,9 +528,11 @@ def parse(cv):
|
||||
cv[k] = [a for _, a in cv[k].items()]
|
||||
nms = []
|
||||
for n in cv[k]:
|
||||
if type(n) != type({}) or "name" not in n or not n.get("name"): continue
|
||||
if not isinstance(n, dict) or "name" not in n or not n.get("name"):
|
||||
continue
|
||||
n["name"] = re.sub(r"((442)|\t )", "", n["name"]).strip().lower()
|
||||
if not n["name"]: continue
|
||||
if not n["name"]:
|
||||
continue
|
||||
nms.append(n["name"])
|
||||
if nms:
|
||||
t = k[:-4]
|
||||
@ -469,15 +545,18 @@ def parse(cv):
|
||||
# tokenize fields
|
||||
if k in tks_fld:
|
||||
cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k])
|
||||
if k in small_tks_fld: cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
|
||||
if k in small_tks_fld:
|
||||
cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
|
||||
|
||||
# keyword fields
|
||||
if k in kwd_fld: cv[f"{k}_kwd"] = [n.lower()
|
||||
if k in kwd_fld:
|
||||
cv[f"{k}_kwd"] = [n.lower()
|
||||
for n in re.split(r"[\t,,;;. ]",
|
||||
re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1,\2", cv[k])
|
||||
) if n]
|
||||
|
||||
if k in num_fld and cv.get(k): cv[f"{k}_int"] = cv[k]
|
||||
if k in num_fld and cv.get(k):
|
||||
cv[f"{k}_int"] = cv[k]
|
||||
|
||||
cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "")
|
||||
# for name field
|
||||
@ -501,10 +580,12 @@ def parse(cv):
|
||||
cv["name_py_pref0_tks"] = ""
|
||||
cv["name_py_pref_tks"] = ""
|
||||
for py in PY.get_pinyins(nm[:20], ''):
|
||||
for i in range(2, len(py) + 1): cv["name_py_pref_tks"] += " " + py[:i]
|
||||
for i in range(2, len(py) + 1):
|
||||
cv["name_py_pref_tks"] += " " + py[:i]
|
||||
for py in PY.get_pinyins(nm[:20], ' '):
|
||||
py = py.split()
|
||||
for i in range(1, len(py) + 1): cv["name_py_pref0_tks"] += " " + "".join(py[:i])
|
||||
for i in range(1, len(py) + 1):
|
||||
cv["name_py_pref0_tks"] += " " + "".join(py[:i])
|
||||
|
||||
cv["name_kwd"] = name
|
||||
cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3]
|
||||
@ -526,22 +607,30 @@ def parse(cv):
|
||||
cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S')
|
||||
else:
|
||||
y, m, d = getYMD(str(cv.get("updated_at", "")))
|
||||
if not y: y = "2012"
|
||||
if not m: m = "01"
|
||||
if not d: d = "01"
|
||||
if not y:
|
||||
y = "2012"
|
||||
if not m:
|
||||
m = "01"
|
||||
if not d:
|
||||
d = "01"
|
||||
cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
|
||||
# long text tokenize
|
||||
|
||||
if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
|
||||
if cv.get("responsibilities"):
|
||||
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
|
||||
|
||||
# for yes or no field
|
||||
fea = []
|
||||
for f, y, n in is_fld:
|
||||
if f not in cv: continue
|
||||
if cv[f] == '是': fea.append(y)
|
||||
if cv[f] == '否': fea.append(n)
|
||||
if f not in cv:
|
||||
continue
|
||||
if cv[f] == '是':
|
||||
fea.append(y)
|
||||
if cv[f] == '否':
|
||||
fea.append(n)
|
||||
|
||||
if fea: cv["tag_kwd"] = fea
|
||||
if fea:
|
||||
cv["tag_kwd"] = fea
|
||||
|
||||
cv = forEdu(cv)
|
||||
cv = forProj(cv)
|
||||
@ -550,9 +639,11 @@ def parse(cv):
|
||||
|
||||
cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])]
|
||||
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
|
||||
for j in cv.get("sch_rank_kwd", []): cv["corp_proj_sch_deg_kwd"][i] += "+" + j
|
||||
for j in cv.get("sch_rank_kwd", []):
|
||||
cv["corp_proj_sch_deg_kwd"][i] += "+" + j
|
||||
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
|
||||
if cv.get("highest_degree_kwd"): cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
|
||||
if cv.get("highest_degree_kwd"):
|
||||
cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
|
||||
|
||||
try:
|
||||
if not cv.get("work_exp_flt") and cv.get("work_start_time"):
|
||||
@ -565,17 +656,21 @@ def parse(cv):
|
||||
cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y)
|
||||
except Exception as e:
|
||||
logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time")))
|
||||
if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
|
||||
if "work_exp_flt" not in cv and cv.get("work_experience", 0):
|
||||
cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
|
||||
|
||||
keys = list(cv.keys())
|
||||
for k in keys:
|
||||
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k): del cv[k]
|
||||
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k):
|
||||
del cv[k]
|
||||
for k in cv.keys():
|
||||
if not re.search("_(kwd|id)$", k) or type(cv[k]) != type([]): continue
|
||||
if not re.search("_(kwd|id)$", k) or not isinstance(cv[k], list):
|
||||
continue
|
||||
cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']]))
|
||||
keys = [k for k in cv.keys() if re.search(r"_feas*$", k)]
|
||||
for k in keys:
|
||||
if cv[k] <= 0: del cv[k]
|
||||
if cv[k] <= 0:
|
||||
del cv[k]
|
||||
|
||||
cv["tob_resume_id"] = str(cv["tob_resume_id"])
|
||||
cv["id"] = cv["tob_resume_id"]
|
||||
@ -592,5 +687,6 @@ def dealWithInt64(d):
|
||||
if isinstance(d, list):
|
||||
d = [dealWithInt64(t) for t in d]
|
||||
|
||||
if isinstance(d, np.integer): d = int(d)
|
||||
if isinstance(d, np.integer):
|
||||
d = int(d)
|
||||
return d
|
||||
|
@ -51,6 +51,7 @@ class RAGFlowTxtParser:
|
||||
dels = [d for d in dels if d]
|
||||
dels = "|".join(dels)
|
||||
secs = re.split(r"(%s)" % dels, txt)
|
||||
for sec in secs: add_chunk(sec)
|
||||
for sec in secs:
|
||||
add_chunk(sec)
|
||||
|
||||
return [[c, ""] for c in cks]
|
||||
|
@ -18,7 +18,6 @@ from .recognizer import Recognizer
|
||||
from .layout_recognizer import LayoutRecognizer
|
||||
from .table_structure_recognizer import TableStructureRecognizer
|
||||
|
||||
|
||||
def init_in_out(args):
|
||||
from PIL import Image
|
||||
import os
|
||||
@ -47,7 +46,7 @@ def init_in_out(args):
|
||||
try:
|
||||
images.append(Image.open(fnm))
|
||||
outputs.append(os.path.split(fnm)[-1])
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if os.path.isdir(args.inputs):
|
||||
@ -56,6 +55,16 @@ def init_in_out(args):
|
||||
else:
|
||||
images_and_outputs(args.inputs)
|
||||
|
||||
for i in range(len(outputs)): outputs[i] = os.path.join(args.output_dir, outputs[i])
|
||||
for i in range(len(outputs)):
|
||||
outputs[i] = os.path.join(args.output_dir, outputs[i])
|
||||
|
||||
return images, outputs
|
||||
return images, outputs
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OCR",
|
||||
"Recognizer",
|
||||
"LayoutRecognizer",
|
||||
"TableStructureRecognizer",
|
||||
"init_in_out",
|
||||
]
|
||||
|
@ -42,7 +42,7 @@ class LayoutRecognizer(Recognizer):
|
||||
get_project_base_directory(),
|
||||
"rag/res/deepdoc")
|
||||
super().__init__(self.labels, domain, model_dir)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
@ -77,7 +77,7 @@ class LayoutRecognizer(Recognizer):
|
||||
"page_number": pn,
|
||||
} for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
|
||||
lts = self.sort_Y_firstly(lts, np.mean(
|
||||
[l["bottom"] - l["top"] for l in lts]) / 2)
|
||||
[lt["bottom"] - lt["top"] for lt in lts]) / 2)
|
||||
lts = self.layouts_cleanup(bxs, lts)
|
||||
page_layout.append(lts)
|
||||
|
||||
|
@ -19,7 +19,9 @@ from huggingface_hub import snapshot_download
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from .operators import *
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
import onnxruntime as ort
|
||||
|
||||
from .postprocess import build_post_process
|
||||
@ -484,7 +486,7 @@ class OCR(object):
|
||||
"rag/res/deepdoc")
|
||||
self.text_detector = TextDetector(model_dir)
|
||||
self.text_recognizer = TextRecognizer(model_dir)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
|
@ -232,7 +232,7 @@ class LinearResize(object):
|
||||
"""
|
||||
assert len(self.target_size) == 2
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
im_channel = im.shape[2]
|
||||
_im_channel = im.shape[2]
|
||||
im_scale_y, im_scale_x = self.generate_scale(im)
|
||||
im = cv2.resize(
|
||||
im,
|
||||
@ -255,7 +255,7 @@ class LinearResize(object):
|
||||
im_scale_y: the resize ratio of Y
|
||||
"""
|
||||
origin_shape = im.shape[:2]
|
||||
im_c = im.shape[2]
|
||||
_im_c = im.shape[2]
|
||||
if self.keep_ratio:
|
||||
im_size_min = np.min(origin_shape)
|
||||
im_size_max = np.max(origin_shape)
|
||||
@ -581,7 +581,7 @@ class SRResize(object):
|
||||
return data
|
||||
|
||||
images_HR = data["image_hr"]
|
||||
label_strs = data["label"]
|
||||
_label_strs = data["label"]
|
||||
transform = ResizeNormalize((imgW, imgH))
|
||||
images_HR = transform(images_HR)
|
||||
data["img_hr"] = images_HR
|
||||
|
@ -121,7 +121,7 @@ class DBPostProcess(object):
|
||||
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
|
||||
cv2.CHAIN_APPROX_SIMPLE)
|
||||
if len(outs) == 3:
|
||||
img, contours, _ = outs[0], outs[1], outs[2]
|
||||
_img, contours, _ = outs[0], outs[1], outs[2]
|
||||
elif len(outs) == 2:
|
||||
contours, _ = outs[0], outs[1]
|
||||
|
||||
|
@ -13,15 +13,18 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from .operators import *
|
||||
|
||||
|
||||
class Recognizer(object):
|
||||
def __init__(self, label_list, task_name, model_dir=None):
|
||||
"""
|
||||
@ -277,7 +280,8 @@ class Recognizer(object):
|
||||
return
|
||||
min_dis, min_i = 1000000, None
|
||||
for i,b in enumerate(boxes):
|
||||
if box.get("layoutno", "0") != b.get("layoutno", "0"): continue
|
||||
if box.get("layoutno", "0") != b.get("layoutno", "0"):
|
||||
continue
|
||||
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
|
||||
if dis < min_dis:
|
||||
min_i = i
|
||||
@ -402,7 +406,8 @@ class Recognizer(object):
|
||||
scores = np.max(boxes[:, 4:], axis=1)
|
||||
boxes = boxes[scores > thr, :]
|
||||
scores = scores[scores > thr]
|
||||
if len(boxes) == 0: return []
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
# Get the class with the highest confidence
|
||||
class_ids = np.argmax(boxes[:, 4:], axis=1)
|
||||
@ -432,7 +437,8 @@ class Recognizer(object):
|
||||
for i in range(len(image_list)):
|
||||
if not isinstance(image_list[i], np.ndarray):
|
||||
imgs.append(np.array(image_list[i]))
|
||||
else: imgs.append(image_list[i])
|
||||
else:
|
||||
imgs.append(image_list[i])
|
||||
|
||||
batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
|
||||
for i in range(batch_loop_cnt):
|
||||
|
@ -88,7 +88,8 @@ class CommunityReportsExtractor:
|
||||
("findings", list),
|
||||
("rating", float),
|
||||
("rating_explanation", str),
|
||||
]): continue
|
||||
]):
|
||||
continue
|
||||
response["weight"] = weight
|
||||
response["entities"] = ents
|
||||
except Exception as e:
|
||||
@ -100,7 +101,8 @@ class CommunityReportsExtractor:
|
||||
res_str.append(self._get_text_output(response))
|
||||
res_dict.append(response)
|
||||
over += 1
|
||||
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
|
||||
if callback:
|
||||
callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
|
||||
|
||||
return CommunityReportsResult(
|
||||
structured_output=res_dict,
|
||||
|
@ -8,6 +8,7 @@ Reference:
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from dataclasses import dataclass
|
||||
from graphrag.leiden import stable_largest_connected_component
|
||||
|
||||
|
||||
|
@ -129,9 +129,11 @@ class GraphExtractor:
|
||||
source_doc_map[doc_index] = text
|
||||
all_records[doc_index] = result
|
||||
total_token_count += token_count
|
||||
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
|
||||
if callback:
|
||||
callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
|
||||
except Exception as e:
|
||||
if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e)))
|
||||
if callback:
|
||||
callback(msg="Knowledge graph extraction error:{}".format(str(e)))
|
||||
logging.exception("error extracting graph")
|
||||
self._on_error(
|
||||
e,
|
||||
@ -164,7 +166,8 @@ class GraphExtractor:
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.3}
|
||||
response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
if response.find("**ERROR**") >= 0: raise Exception(response)
|
||||
if response.find("**ERROR**") >= 0:
|
||||
raise Exception(response)
|
||||
token_count = num_tokens_from_string(text + response)
|
||||
|
||||
results = response or ""
|
||||
@ -175,7 +178,8 @@ class GraphExtractor:
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
response = self._llm.chat("", history, gen_conf)
|
||||
if response.find("**ERROR**") >=0: raise Exception(response)
|
||||
if response.find("**ERROR**") >=0:
|
||||
raise Exception(response)
|
||||
results += response or ""
|
||||
|
||||
# if this is the final glean, don't bother updating the continuation flag
|
||||
|
@ -134,7 +134,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
|
||||
callback(0.75, "Extracting mind graph.")
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
mg = mindmap(_chunks).output
|
||||
if not len(mg.keys()): return chunks
|
||||
if not len(mg.keys()):
|
||||
return chunks
|
||||
|
||||
logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
|
||||
chunks.append(
|
||||
|
@ -78,7 +78,8 @@ def _compute_leiden_communities(
|
||||
) -> dict[int, dict[str, int]]:
|
||||
"""Return Leiden root communities."""
|
||||
results: dict[int, dict[str, int]] = {}
|
||||
if is_empty(graph): return results
|
||||
if is_empty(graph):
|
||||
return results
|
||||
if use_lcc:
|
||||
graph = stable_largest_connected_component(graph)
|
||||
|
||||
@ -100,7 +101,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
|
||||
logging.debug(
|
||||
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
|
||||
)
|
||||
if not graph.nodes(): return {}
|
||||
if not graph.nodes():
|
||||
return {}
|
||||
|
||||
node_id_to_community_map = _compute_leiden_communities(
|
||||
graph=graph,
|
||||
@ -125,9 +127,11 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
|
||||
result[community_id]["nodes"].append(node_id)
|
||||
result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
|
||||
weights = [comm["weight"] for _, comm in result.items()]
|
||||
if not weights:continue
|
||||
if not weights:
|
||||
continue
|
||||
max_weight = max(weights)
|
||||
for _, comm in result.items(): comm["weight"] /= max_weight
|
||||
for _, comm in result.items():
|
||||
comm["weight"] /= max_weight
|
||||
|
||||
return results_by_level
|
||||
|
||||
|
@ -1 +1,5 @@
|
||||
from .ragflow_chat import *
|
||||
from .ragflow_chat import RAGFlowChat
|
||||
|
||||
__all__ = [
|
||||
"RAGFlowChat"
|
||||
]
|
||||
|
@ -2,7 +2,6 @@ import logging
|
||||
import requests
|
||||
from bridge.context import ContextType # Import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType # Import Reply, ReplyType
|
||||
from bridge import *
|
||||
from plugins import Plugin, register # Import Plugin and register
|
||||
from plugins.event import Event, EventContext, EventAction # Import event-related classes
|
||||
|
||||
|
@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
sections = txt.split("\n")
|
||||
sections = [(l, "") for l in sections if l]
|
||||
sections = [(line, "") for line in sections if line]
|
||||
remove_contents_table(sections, eng=is_english(
|
||||
random_choices([t for t, _ in sections], k=200)))
|
||||
callback(0.8, "Finish parsing.")
|
||||
@ -102,7 +102,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
sections = HtmlParser()(filename, binary)
|
||||
sections = [(l, "") for l in sections if l]
|
||||
sections = [(line, "") for line in sections if line]
|
||||
remove_contents_table(sections, eng=is_english(
|
||||
random_choices([t for t, _ in sections], k=200)))
|
||||
callback(0.8, "Finish parsing.")
|
||||
@ -112,7 +112,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
binary = BytesIO(binary)
|
||||
doc_parsed = parser.from_buffer(binary)
|
||||
sections = doc_parsed['content'].split('\n')
|
||||
sections = [(l, "") for l in sections if l]
|
||||
sections = [(line, "") for line in sections if line]
|
||||
remove_contents_table(sections, eng=is_english(
|
||||
random_choices([t for t, _ in sections], k=200)))
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
@ -75,7 +75,7 @@ def chunk(
|
||||
_add_content(msg, msg.get_content_type())
|
||||
|
||||
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
|
||||
(l, "") for l in HtmlParser.parser_txt("\n".join(html_txt)) if l
|
||||
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line
|
||||
]
|
||||
|
||||
st = timer()
|
||||
|
@ -18,7 +18,8 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
|
||||
chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
|
||||
parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
|
||||
)
|
||||
for c in chunks: c["docnm_kwd"] = filename
|
||||
for c in chunks:
|
||||
c["docnm_kwd"] = filename
|
||||
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
|
@ -48,7 +48,7 @@ class Docx(DocxParser):
|
||||
continue
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
return [l for l in lines if l]
|
||||
return [line for line in lines if line]
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
|
||||
self.doc = Document(
|
||||
@ -60,7 +60,8 @@ class Docx(DocxParser):
|
||||
if pn > to_page:
|
||||
break
|
||||
question_level, p_text = docx_question_level(p, bull)
|
||||
if not p_text.strip("\n"):continue
|
||||
if not p_text.strip("\n"):
|
||||
continue
|
||||
lines.append((question_level, p_text))
|
||||
|
||||
for run in p.runs:
|
||||
@ -78,19 +79,21 @@ class Docx(DocxParser):
|
||||
if lines[e][0] <= lines[s][0]:
|
||||
break
|
||||
e += 1
|
||||
if e - s == 1 and visit[s]: continue
|
||||
if e - s == 1 and visit[s]:
|
||||
continue
|
||||
sec = []
|
||||
next_level = lines[s][0] + 1
|
||||
while not sec and next_level < 22:
|
||||
for i in range(s+1, e):
|
||||
if lines[i][0] != next_level: continue
|
||||
if lines[i][0] != next_level:
|
||||
continue
|
||||
sec.append(lines[i][1])
|
||||
visit[i] = True
|
||||
next_level += 1
|
||||
sec.insert(0, lines[s][1])
|
||||
|
||||
sections.append("\n".join(sec))
|
||||
return [l for l in sections if l]
|
||||
return [s for s in sections if s]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'''
|
||||
@ -168,13 +171,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
sections = txt.split("\n")
|
||||
sections = [l for l in sections if l]
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
sections = HtmlParser()(filename, binary)
|
||||
sections = [l for l in sections if l]
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||
@ -182,7 +185,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
binary = BytesIO(binary)
|
||||
doc_parsed = parser.from_buffer(binary)
|
||||
sections = doc_parsed['content'].split('\n')
|
||||
sections = [l for l in sections if l]
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
else:
|
||||
|
@ -190,7 +190,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
sections, tbls = pdf_parser(filename if not binary else binary,
|
||||
from_page=from_page, to_page=to_page, callback=callback)
|
||||
if sections and len(sections[0]) < 3:
|
||||
sections = [(t, l, [[0] * 5]) for t, l in sections]
|
||||
sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
|
||||
# set pivot using the most frequent type of title,
|
||||
# then merge between 2 pivot
|
||||
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1:
|
||||
@ -211,7 +211,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
else:
|
||||
bull = bullets_category([txt for txt, _, _ in sections])
|
||||
most_level, levels = title_frequency(
|
||||
bull, [(txt, l) for txt, l, poss in sections])
|
||||
bull, [(txt, lvl) for txt, lvl, _ in sections])
|
||||
|
||||
assert len(sections) == len(levels)
|
||||
sec_ids = []
|
||||
@ -225,7 +225,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
sections = [(txt, sec_ids[i], poss)
|
||||
for i, (txt, _, poss) in enumerate(sections)]
|
||||
for (img, rows), poss in tbls:
|
||||
if not rows: continue
|
||||
if not rows:
|
||||
continue
|
||||
sections.append((rows if isinstance(rows, str) else rows[0], -1,
|
||||
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
|
||||
|
||||
|
@ -54,7 +54,8 @@ class Pdf(PdfParser):
|
||||
sections = [(b["text"], self.get_position(b, zoomin))
|
||||
for i, b in enumerate(self.boxes)]
|
||||
for (img, rows), poss in tbls:
|
||||
if not rows:continue
|
||||
if not rows:
|
||||
continue
|
||||
sections.append((rows if isinstance(rows, str) else rows[0],
|
||||
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
|
||||
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
|
||||
@ -109,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
binary = BytesIO(binary)
|
||||
doc_parsed = parser.from_buffer(binary)
|
||||
sections = doc_parsed['content'].split('\n')
|
||||
sections = [l for l in sections if l]
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
else:
|
||||
|
@ -171,7 +171,7 @@ class Pdf(PdfParser):
|
||||
tbl_bottom = tbls[tbl_index][1][0][4]
|
||||
tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
|
||||
.format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
|
||||
tbl_text = ''.join(tbls[tbl_index][0][1])
|
||||
_tbl_text = ''.join(tbls[tbl_index][0][1])
|
||||
return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag,
|
||||
|
||||
|
||||
@ -325,9 +325,11 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||
txt = get_text(filename, binary)
|
||||
lines = txt.split("\n")
|
||||
comma, tab = 0, 0
|
||||
for l in lines:
|
||||
if len(l.split(",")) == 2: comma += 1
|
||||
if len(l.split("\t")) == 2: tab += 1
|
||||
for line in lines:
|
||||
if len(line.split(",")) == 2:
|
||||
comma += 1
|
||||
if len(line.split("\t")) == 2:
|
||||
tab += 1
|
||||
delimiter = "\t" if tab >= comma else ","
|
||||
|
||||
fails = []
|
||||
@ -336,18 +338,21 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||
while i < len(lines):
|
||||
arr = lines[i].split(delimiter)
|
||||
if len(arr) != 2:
|
||||
if question: answer += "\n" + lines[i]
|
||||
if question:
|
||||
answer += "\n" + lines[i]
|
||||
else:
|
||||
fails.append(str(i+1))
|
||||
elif len(arr) == 2:
|
||||
if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
||||
if question and answer:
|
||||
res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
||||
question, answer = arr
|
||||
i += 1
|
||||
if len(res) % 999 == 0:
|
||||
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
if question: res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
||||
if question:
|
||||
res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
||||
|
||||
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
@ -367,19 +372,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
lines = txt.split("\n")
|
||||
last_question, last_answer = "", ""
|
||||
_last_question, last_answer = "", ""
|
||||
question_stack, level_stack = [], []
|
||||
code_block = False
|
||||
level_index = [-1] * 7
|
||||
for index, l in enumerate(lines):
|
||||
if l.strip().startswith('```'):
|
||||
for index, line in enumerate(lines):
|
||||
if line.strip().startswith('```'):
|
||||
code_block = not code_block
|
||||
question_level, question = 0, ''
|
||||
if not code_block:
|
||||
question_level, question = mdQuestionLevel(l)
|
||||
question_level, question = mdQuestionLevel(line)
|
||||
|
||||
if not question_level or question_level > 6: # not a question
|
||||
last_answer = f'{last_answer}\n{l}'
|
||||
last_answer = f'{last_answer}\n{line}'
|
||||
else: # is a question
|
||||
if last_answer.strip():
|
||||
sum_question = '\n'.join(question_stack)
|
||||
|
@ -41,14 +41,16 @@ class Excel(ExcelParser):
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
rows = list(ws.rows)
|
||||
if not rows:continue
|
||||
if not rows:
|
||||
continue
|
||||
headers = [cell.value for cell in rows[0]]
|
||||
missed = set([i for i, h in enumerate(headers) if h is None])
|
||||
headers = [
|
||||
cell.value for i,
|
||||
cell in enumerate(
|
||||
rows[0]) if i not in missed]
|
||||
if not headers:continue
|
||||
if not headers:
|
||||
continue
|
||||
data = []
|
||||
for i, r in enumerate(rows[1:]):
|
||||
rn += 1
|
||||
@ -88,7 +90,6 @@ def trans_bool(s):
|
||||
|
||||
def column_data_type(arr):
|
||||
arr = list(arr)
|
||||
uni = len(set([a for a in arr if a is not None]))
|
||||
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
|
||||
trans = {t: f for f, t in
|
||||
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
|
||||
@ -157,7 +158,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
|
||||
continue
|
||||
if i >= to_page:
|
||||
break
|
||||
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
|
||||
row = [field for field in line.split(kwargs.get("delimiter", "\t"))]
|
||||
if len(row) != len(headers):
|
||||
fails.append(str(i))
|
||||
continue
|
||||
|
@ -13,12 +13,124 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from .embedding_model import *
|
||||
from .chat_model import *
|
||||
from .cv_model import *
|
||||
from .rerank_model import *
|
||||
from .sequence2txt_model import *
|
||||
from .tts_model import *
|
||||
from .embedding_model import (
|
||||
OllamaEmbed,
|
||||
LocalAIEmbed,
|
||||
OpenAIEmbed,
|
||||
AzureEmbed,
|
||||
XinferenceEmbed,
|
||||
QWenEmbed,
|
||||
ZhipuEmbed,
|
||||
FastEmbed,
|
||||
YoudaoEmbed,
|
||||
BaiChuanEmbed,
|
||||
JinaEmbed,
|
||||
DefaultEmbedding,
|
||||
MistralEmbed,
|
||||
BedrockEmbed,
|
||||
GeminiEmbed,
|
||||
NvidiaEmbed,
|
||||
LmStudioEmbed,
|
||||
OpenAI_APIEmbed,
|
||||
CoHereEmbed,
|
||||
TogetherAIEmbed,
|
||||
PerfXCloudEmbed,
|
||||
UpstageEmbed,
|
||||
SILICONFLOWEmbed,
|
||||
ReplicateEmbed,
|
||||
BaiduYiyanEmbed,
|
||||
VoyageEmbed,
|
||||
HuggingFaceEmbed,
|
||||
VolcEngineEmbed,
|
||||
)
|
||||
from .chat_model import (
|
||||
GptTurbo,
|
||||
AzureChat,
|
||||
ZhipuChat,
|
||||
QWenChat,
|
||||
OllamaChat,
|
||||
LocalAIChat,
|
||||
XinferenceChat,
|
||||
MoonshotChat,
|
||||
DeepSeekChat,
|
||||
VolcEngineChat,
|
||||
BaiChuanChat,
|
||||
MiniMaxChat,
|
||||
MistralChat,
|
||||
GeminiChat,
|
||||
BedrockChat,
|
||||
GroqChat,
|
||||
OpenRouterChat,
|
||||
StepFunChat,
|
||||
NvidiaChat,
|
||||
LmStudioChat,
|
||||
OpenAI_APIChat,
|
||||
CoHereChat,
|
||||
LeptonAIChat,
|
||||
TogetherAIChat,
|
||||
PerfXCloudChat,
|
||||
UpstageChat,
|
||||
NovitaAIChat,
|
||||
SILICONFLOWChat,
|
||||
YiChat,
|
||||
ReplicateChat,
|
||||
HunyuanChat,
|
||||
SparkChat,
|
||||
BaiduYiyanChat,
|
||||
AnthropicChat,
|
||||
GoogleChat,
|
||||
HuggingFaceChat,
|
||||
)
|
||||
|
||||
from .cv_model import (
|
||||
GptV4,
|
||||
AzureGptV4,
|
||||
OllamaCV,
|
||||
XinferenceCV,
|
||||
QWenCV,
|
||||
Zhipu4V,
|
||||
LocalCV,
|
||||
GeminiCV,
|
||||
OpenRouterCV,
|
||||
LocalAICV,
|
||||
NvidiaCV,
|
||||
LmStudioCV,
|
||||
StepFunCV,
|
||||
OpenAI_APICV,
|
||||
TogetherAICV,
|
||||
YiCV,
|
||||
HunyuanCV,
|
||||
)
|
||||
from .rerank_model import (
|
||||
LocalAIRerank,
|
||||
DefaultRerank,
|
||||
JinaRerank,
|
||||
YoudaoRerank,
|
||||
XInferenceRerank,
|
||||
NvidiaRerank,
|
||||
LmStudioRerank,
|
||||
OpenAI_APIRerank,
|
||||
CoHereRerank,
|
||||
TogetherAIRerank,
|
||||
SILICONFLOWRerank,
|
||||
BaiduYiyanRerank,
|
||||
VoyageRerank,
|
||||
QWenRerank,
|
||||
)
|
||||
from .sequence2txt_model import (
|
||||
GPTSeq2txt,
|
||||
QWenSeq2txt,
|
||||
AzureSeq2txt,
|
||||
XinferenceSeq2txt,
|
||||
TencentCloudSeq2txt,
|
||||
)
|
||||
from .tts_model import (
|
||||
FishAudioTTS,
|
||||
QwenTTS,
|
||||
OpenAITTS,
|
||||
SparkTTS,
|
||||
XinferenceTTS,
|
||||
)
|
||||
|
||||
EmbeddingModel = {
|
||||
"Ollama": OllamaEmbed,
|
||||
@ -48,7 +160,7 @@ EmbeddingModel = {
|
||||
"BaiduYiyan": BaiduYiyanEmbed,
|
||||
"Voyage AI": VoyageEmbed,
|
||||
"HuggingFace": HuggingFaceEmbed,
|
||||
"VolcEngine":VolcEngineEmbed,
|
||||
"VolcEngine": VolcEngineEmbed,
|
||||
}
|
||||
|
||||
CvModel = {
|
||||
@ -68,7 +180,7 @@ CvModel = {
|
||||
"OpenAI-API-Compatible": OpenAI_APICV,
|
||||
"TogetherAI": TogetherAICV,
|
||||
"01.AI": YiCV,
|
||||
"Tencent Hunyuan": HunyuanCV
|
||||
"Tencent Hunyuan": HunyuanCV,
|
||||
}
|
||||
|
||||
ChatModel = {
|
||||
@ -111,7 +223,7 @@ ChatModel = {
|
||||
}
|
||||
|
||||
RerankModel = {
|
||||
"LocalAI":LocalAIRerank,
|
||||
"LocalAI": LocalAIRerank,
|
||||
"BAAI": DefaultRerank,
|
||||
"Jina": JinaRerank,
|
||||
"Youdao": YoudaoRerank,
|
||||
@ -132,7 +244,7 @@ Seq2txtModel = {
|
||||
"Tongyi-Qianwen": QWenSeq2txt,
|
||||
"Azure-OpenAI": AzureSeq2txt,
|
||||
"Xinference": XinferenceSeq2txt,
|
||||
"Tencent Cloud": TencentCloudSeq2txt
|
||||
"Tencent Cloud": TencentCloudSeq2txt,
|
||||
}
|
||||
|
||||
TTSModel = {
|
||||
|
@ -69,7 +69,8 @@ class Base(ABC):
|
||||
stream=True,
|
||||
**gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices: continue
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
ans += resp.choices[0].delta.content
|
||||
@ -81,7 +82,8 @@ class Base(ABC):
|
||||
)
|
||||
elif isinstance(resp.usage, dict):
|
||||
total_tokens = resp.usage.get("total_tokens", total_tokens)
|
||||
else: total_tokens = resp.usage.total_tokens
|
||||
else:
|
||||
total_tokens = resp.usage.total_tokens
|
||||
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
@ -98,13 +100,15 @@ class Base(ABC):
|
||||
|
||||
class GptTurbo(Base):
|
||||
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
||||
if not base_url: base_url = "https://api.openai.com/v1"
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class MoonshotChat(Base):
|
||||
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
|
||||
if not base_url: base_url = "https://api.moonshot.cn/v1"
|
||||
if not base_url:
|
||||
base_url = "https://api.moonshot.cn/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
@ -128,7 +132,8 @@ class HuggingFaceChat(Base):
|
||||
|
||||
class DeepSeekChat(Base):
|
||||
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
|
||||
if not base_url: base_url = "https://api.deepseek.com/v1"
|
||||
if not base_url:
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
@ -202,7 +207,8 @@ class BaiChuanChat(Base):
|
||||
stream=True,
|
||||
**self._format_params(gen_conf))
|
||||
for resp in response:
|
||||
if not resp.choices: continue
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
ans += resp.choices[0].delta.content
|
||||
@ -313,8 +319,10 @@ class ZhipuChat(Base):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
try:
|
||||
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
del gen_conf["frequency_penalty"]
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
@ -333,8 +341,10 @@ class ZhipuChat(Base):
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
del gen_conf["frequency_penalty"]
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
@ -345,7 +355,8 @@ class ZhipuChat(Base):
|
||||
**gen_conf
|
||||
)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content: continue
|
||||
if not resp.choices[0].delta.content:
|
||||
continue
|
||||
delta = resp.choices[0].delta.content
|
||||
ans += delta
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
@ -354,7 +365,8 @@ class ZhipuChat(Base):
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
tk_count = resp.usage.total_tokens
|
||||
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
tk_count = resp.usage.total_tokens
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
@ -372,11 +384,16 @@ class OllamaChat(Base):
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
try:
|
||||
options = {}
|
||||
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf:
|
||||
options["num_predict"] = gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf:
|
||||
options["top_p"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
@ -392,11 +409,16 @@ class OllamaChat(Base):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
options = {}
|
||||
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf:
|
||||
options["num_predict"] = gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf:
|
||||
options["top_p"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(
|
||||
@ -636,7 +658,8 @@ class MistralChat(Base):
|
||||
messages=history,
|
||||
**gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices or not resp.choices[0].delta.content: continue
|
||||
if not resp.choices or not resp.choices[0].delta.content:
|
||||
continue
|
||||
ans += resp.choices[0].delta.content
|
||||
total_tokens += 1
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
@ -1196,7 +1219,8 @@ class SparkChat(Base):
|
||||
assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
|
||||
if model_name in model2version:
|
||||
model_version = model2version[model_name]
|
||||
else: model_version = model_name
|
||||
else:
|
||||
model_version = model_name
|
||||
super().__init__(key, model_version, base_url)
|
||||
|
||||
|
||||
@ -1281,8 +1305,10 @@ class AnthropicChat(Base):
|
||||
self.system = system
|
||||
if "max_tokens" not in gen_conf:
|
||||
gen_conf["max_tokens"] = 4096
|
||||
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
del gen_conf["frequency_penalty"]
|
||||
|
||||
ans = ""
|
||||
try:
|
||||
@ -1312,8 +1338,10 @@ class AnthropicChat(Base):
|
||||
self.system = system
|
||||
if "max_tokens" not in gen_conf:
|
||||
gen_conf["max_tokens"] = 4096
|
||||
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
del gen_conf["frequency_penalty"]
|
||||
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
|
@ -25,6 +25,7 @@ import base64
from io import BytesIO
import json
import requests
from transformers import GenerationConfig

from rag.nlp import is_english
from api.utils import get_uuid
@ -77,14 +78,16 @@ class Base(ABC):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -99,7 +102,7 @@ class Base(ABC):
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
except Exception:
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")

@ -139,7 +142,8 @@ class Base(ABC):

class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url:
base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@ -149,7 +153,8 @@ class GptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"

res = self.client.chat.completions.create(
model=self.model_name,
@ -171,7 +176,8 @@ class AzureGptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"

res = self.client.chat.completions.create(
model=self.model_name,
@ -344,14 +350,16 @@ class Zhipu4V(Base):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -389,11 +397,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@ -414,11 +427,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@ -469,7 +487,7 @@ class XinferenceCV(Base):

class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client, GenerativeModel, GenerationConfig
from google.generativeai import client, GenerativeModel
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
@ -503,7 +521,7 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)

response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
@ -519,7 +537,6 @@ class GeminiCV(Base):
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

ans = ""
tk_count = 0
try:
for his in history:
if his["role"] == "assistant":
@ -529,14 +546,15 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)

response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)), stream=True)

for resp in response:
if not resp.text: continue
if not resp.text:
continue
ans += resp.text
yield ans
except Exception as e:
@ -632,7 +650,8 @@ class NvidiaCV(Base):

class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url: base_url="https://api.stepfun.com/v1"
if not base_url:
base_url="https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang

@ -15,12 +15,9 @@
#
import requests
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from openai import OpenAI
import os
import json
from rag.utils import num_tokens_from_string
import base64
@ -49,7 +46,8 @@ class Base(ABC):

class GPTSeq2txt(Base):
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name

@ -16,7 +16,6 @@

import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
@ -175,7 +174,8 @@ class QwenTTS(Base):

class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url

@ -222,7 +222,8 @@ def bullets_category(sections):

def is_english(texts):
eng = 0
if not texts: return False
if not texts:
return False
for t in texts:
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
eng += 1
@ -250,7 +251,8 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = []
# wrap up as es documents
for ck in chunks:
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
if pdf_parser:
@ -269,7 +271,8 @@ def tokenize_chunks_docx(chunks, doc, eng, images):
res = []
# wrap up as es documents
for ck, image in zip(chunks, images):
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
@ -288,8 +291,10 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d = copy.deepcopy(doc)
tokenize(d, rows, eng)
d["content_with_weight"] = rows
if img: d["image"] = img
if poss: add_positions(d, poss)
if img:
d["image"] = img
if poss:
add_positions(d, poss)
res.append(d)
continue
de = "; " if eng else "； "
@ -387,9 +392,9 @@ def title_frequency(bull, sections):
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size
most_level = bullets_size+1
for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if l <= bullets_size:
most_level = l
for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if level <= bullets_size:
most_level = level
break
return most_level, levels

@ -504,7 +509,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos: pos = ""
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num

@ -121,7 +121,8 @@ class FulltextQueryer:
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32: keywords.extend(syns)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
@ -147,7 +148,8 @@ class FulltextQueryer:

tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
if len(keywords) < 32: keywords.extend([s for s in tk_syns if s])
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]

@ -104,7 +104,6 @@ class RagTokenizer:
return HanziConv.toSimplified(line)

def dfs_(self, chars, s, preTks, tkslist):
MAX_L = 10
res = s
# if s > MAX_L or s>= len(chars):
if s >= len(chars):
@ -184,12 +183,6 @@ class RagTokenizer:
return sorted(res, key=lambda x: x[1], reverse=True)

def merge_(self, tks):
patts = [
(r"[ ]+", " "),
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
]
# for p,s in patts: tks = re.sub(p, s, tks)

# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split()
@ -284,7 +277,8 @@ class RagTokenizer:
same = 0
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0: res.append(" ".join(tks[j: j + same]))
if same > 0:
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1

@ -62,10 +62,10 @@ class Dealer:
res = {}
f = open(fnm, "r")
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
arr = l.replace("\n", "").split("\t")
arr = line.replace("\n", "").split("\t")
if len(arr) < 2:
res[arr[0]] = 0
else:

@ -47,7 +47,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
if len(chunks) <= 1: return
if len(chunks) <= 1:
return
chunks = [(s, a) for s, a in chunks if len(a) > 0]

def summarize(ck_idx, lock):
@ -66,7 +67,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt])
with lock:
if not len(embds[0]): return
if not len(embds[0]):
return
chunks.append((cnt, embds[0]))
except Exception as e:
logging.exception("summarize got exception")

@ -33,14 +33,16 @@ def collect():

def main():
locations = collect()
if not locations:return
if not locations:
return
logging.info(f"TASKS: {len(locations)}")
for kb_id, loc in locations:
try:
if REDIS_CONN.is_alive():
try:
key = "{}/{}".format(kb_id, loc)
if REDIS_CONN.exist(key):continue
if REDIS_CONN.exist(key):
continue
file_bin = STORAGE_IMPL.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))

@ -23,18 +23,12 @@ import os

from api.utils.log_utils import initRootLogger

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)

from datetime import datetime
import json
import os
import hashlib
import copy
import re
import sys
import time
import threading
from functools import partial
@ -63,6 +57,11 @@ from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)

BATCH_SIZE = 64

FACTORY = {
@ -201,7 +200,8 @@ def build_chunks(task, progress_callback):
"doc_id": task["doc_id"],
"kb_id": str(task["kb_id"])
}
if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"])
el = 0
for ck in cks:
d = copy.deepcopy(doc)
@ -342,7 +342,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:

@ -41,15 +41,15 @@ def findMaxDt(fnm):
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
l = l.strip("\n")
if l == 'nan':
line = line.strip("\n")
if line == 'nan':
continue
if l > m:
m = l
except Exception as e:
if line > m:
m = line
except Exception:
pass
return m

@ -59,15 +59,15 @@ def findMaxTm(fnm):
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
l = l.strip("\n")
if l == 'nan':
line = line.strip("\n")
if line == 'nan':
continue
if int(l) > m:
m = int(l)
except Exception as e:
if int(line) > m:
m = int(line)
except Exception:
pass
return m

@ -32,7 +32,7 @@ class RAGFlowAzureSasBlob(object):
self.conn = None

def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))

def put(self, bucket, fnm, binary):

@ -36,7 +36,7 @@ class RAGFlowAzureSpnBlob(object):
self.conn = None

def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
f = self.conn.create_file(fnm)
f.append_data(binary, offset=0, length=len(binary))
return f.flush_data(len(binary))

@ -132,7 +132,8 @@ class ESConnection(DocStoreConnection):
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v: continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):

@ -1,14 +1,21 @@
from beartype.claw import beartype_this_package
beartype_this_package()  # <-- raise exceptions in your code

import importlib.metadata

__version__ = importlib.metadata.version("ragflow_sdk")

from .ragflow import RAGFlow
from .modules.dataset import DataSet
from .modules.chat import Chat
from .modules.session import Session
from .modules.document import Document
from .modules.chunk import Chunk
from .modules.agent import Agent
from .modules.agent import Agent

__version__ = importlib.metadata.version("ragflow_sdk")

__all__ = [
"RAGFlow",
"DataSet",
"Chat",
"Session",
"Document",
"Chunk",
"Agent"
]

@ -29,7 +29,7 @@ class Session(Base):
raise Exception(json_data["message"])
if line.startswith("data:"):
json_data = json.loads(line[5:])
if json_data["data"] != True:
if not json_data["data"]:
answer = json_data["data"]["answer"]
reference = json_data["data"]["reference"]
temp_dict = {

@ -1,5 +1,3 @@
import string
import random
import os
import pytest
import requests

@ -39,7 +39,6 @@ def update_dataset(auth, json_req):
def upload_file(auth, dataset_id, path):
authorization = {"Authorization": auth}
url = f"{HOST_ADDRESS}/v1/document/upload"
base_name = os.path.basename(path)
json_req = {
"kb_id": dataset_id,
}

@ -1,3 +1,3 @@
def test_get_email(get_email):
print(f"\nEmail account:",flush=True)
print("\nEmail account:",flush=True)
print(f"{get_email}\n",flush=True)

@ -13,14 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, upload_file, DATASET_NAME_LIMIT
from common import create_dataset, list_dataset, rm_dataset, upload_file
from common import list_document, get_docs_info, parse_docs
from time import sleep
from timeit import default_timer as timer
import re
import pytest
import random
import string

def test_parse_txt_document(get_auth):

@ -1,6 +1,5 @@
from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
from common import create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
import re
import pytest
import random
import string

@ -33,8 +32,6 @@ def test_dataset(get_auth):

def test_dataset_1k_dataset(get_auth):
# create dataset
authorization = {"Authorization": get_auth}
url = f"{HOST_ADDRESS}/v1/kb/create"
for i in range(1000):
res = create_dataset(get_auth, f"test_create_dataset_{i}")
assert res.get("code") == 0, f"{res.get('message')}"
@ -76,7 +73,7 @@ def test_duplicated_name_dataset(get_auth):
dataset_id = item.get("id")
dataset_list.append(dataset_id)
match = re.match(pattern, dataset_name)
assert match != None
assert match is not None

for dataset_id in dataset_list:
res = rm_dataset(get_auth, dataset_id)

@ -1,3 +1,3 @@
def test_get_email(get_email):
print(f"\nEmail account:",flush=True)
print("\nEmail account:",flush=True)
print(f"{get_email}\n",flush=True)

@ -1,4 +1,4 @@
from ragflow_sdk import RAGFlow,Agent
from ragflow_sdk import RAGFlow
from common import HOST_ADDRESS
import pytest