Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

Fix errors detected by Ruff: unused imports and variables, bare `except` clauses and unused `except ... as e` bindings, `type()` comparisons instead of `isinstance()`, comparisons to `True`, ambiguous single-letter names such as `l`, and multiple statements placed on one line.

### Type of change

- [x] Refactoring
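
For context, the changes below follow a handful of recurring Ruff findings: statements sharing a line with their `if`/`else`/`for` clause, bare `except:` or unused `except ... as e` handlers, `type()` comparisons instead of `isinstance()`, comparisons to `True`, ambiguous single-letter names like `l`, and unused imports. The sketch below is illustrative only and is not code from this repository; the rule codes in the comments (E701, E712, E722, E741, and F401 for unused imports) are the usual pycodestyle/pyflakes codes Ruff reports for these patterns, which is an assumption about how the project's Ruff configuration is set up.

```python
# Illustrative only -- not code from this PR. Shows the "after" style the diff converges on.

def keep_non_empty(items):
    # E701: "if not items: return []" put two statements on one line; split it.
    if not items:
        return []
    kept = []
    # E741: the loop variable was previously named "l", which Ruff flags as ambiguous.
    for line in items:
        if not line:
            continue
        kept.append(line)
    return kept


def to_float(value):
    # A bare "except:" (E722) or an unused "as e" binding both get rewritten
    # to a plain "except Exception:" in this PR.
    try:
        return float(value)
    except Exception:
        return None


def is_enabled(flag):
    # E712: test truthiness directly instead of writing "flag == True".
    return bool(flag)
```

Running `ruff check` reproduces this class of findings; the project's actual configuration (selected rules, excludes) may enable a different rule set.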
This commit is contained in:
Zhichang Yu 2024-12-08 14:21:12 +08:00 committed by GitHub
parent e267a026f3
commit 0d68a6cd1b
GPG Key ID: B5690EEEBB952194
97 changed files with 2558 additions and 1976 deletions

View File

@@ -133,7 +133,8 @@ class Canvas(ABC):
 "components": {}
 }
 for k in self.dsl.keys():
-if k in ["components"]:continue
+if k in ["components"]:
+    continue
 dsl[k] = deepcopy(self.dsl[k])
 for k, cpn in self.components.items():
@@ -158,7 +159,8 @@ class Canvas(ABC):
 def get_compnent_name(self, cid):
 for n in self.dsl["graph"]["nodes"]:
-if cid == n["id"]: return n["data"]["name"]
+if cid == n["id"]:
+    return n["data"]["name"]
 return ""
 def run(self, **kwargs):
@@ -173,7 +175,8 @@ class Canvas(ABC):
 if kwargs.get("stream"):
 for an in ans():
 yield an
-else: yield ans
+else:
+    yield ans
 return
 if not self.path:
@@ -188,7 +191,8 @@ class Canvas(ABC):
 def prepare2run(cpns):
 nonlocal ran, ans
 for c in cpns:
-if self.path[-1] and c == self.path[-1][-1]: continue
+if self.path[-1] and c == self.path[-1][-1]:
+    continue
 cpn = self.components[c]["obj"]
 if cpn.component_name == "Answer":
 self.answer.append(c)
@@ -197,7 +201,8 @@ class Canvas(ABC):
 if c not in without_dependent_checking:
 cpids = cpn.get_dependent_components()
 if any([cc not in self.path[-1] for cc in cpids]):
-if c not in waiting: waiting.append(c)
+if c not in waiting:
+    waiting.append(c)
 continue
 yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
 ans = cpn.run(self.history, **kwargs)
@@ -211,10 +216,12 @@ class Canvas(ABC):
 logging.debug(f"Canvas.run: {ran} {self.path}")
 cpn_id = self.path[-1][ran]
 cpn = self.get_component(cpn_id)
-if not cpn["downstream"]: break
+if not cpn["downstream"]:
+    break
 loop = self._find_loop()
-if loop: raise OverflowError(f"Too much loops: {loop}")
+if loop:
+    raise OverflowError(f"Too much loops: {loop}")
 if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
 switch_out = cpn["obj"].output()[1].iloc[0, 0]
@@ -283,19 +290,22 @@ class Canvas(ABC):
 def _find_loop(self, max_loops=6):
 path = self.path[-1][::-1]
-if len(path) < 2: return False
+if len(path) < 2:
+    return False
 for i in range(len(path)):
 if path[i].lower().find("answer") >= 0:
 path = path[:i]
 break
-if len(path) < 2: return False
+if len(path) < 2:
+    return False
-for l in range(2, len(path) // 2):
-pat = ",".join(path[0:l])
+for loc in range(2, len(path) // 2):
+pat = ",".join(path[0:loc])
 path_str = ",".join(path)
-if len(pat) >= len(path_str): return False
+if len(pat) >= len(path_str):
+    return False
 loop = max_loops
 while path_str.find(pat) == 0 and loop >= 0:
 loop -= 1
@@ -303,7 +313,7 @@ class Canvas(ABC):
 return False
 path_str = path_str[len(pat)+1:]
 if loop < 0:
-pat = " => ".join([p.split(":")[0] for p in path[0:l]])
+pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
 return pat + " => " + pat
 return False

View File

@@ -39,3 +39,73 @@ def component_class(class_name):
 m = importlib.import_module("agent.component")
 c = getattr(m, class_name)
 return c
+__all__ = [
+    "Begin",
+    "BeginParam",
+    "Generate",
+    "GenerateParam",
+    "Retrieval",
+    "RetrievalParam",
+    "Answer",
+    "AnswerParam",
+    "Categorize",
+    "CategorizeParam",
+    "Switch",
+    "SwitchParam",
+    "Relevant",
+    "RelevantParam",
+    "Message",
+    "MessageParam",
+    "RewriteQuestion",
+    "RewriteQuestionParam",
+    "KeywordExtract",
+    "KeywordExtractParam",
+    "Concentrator",
+    "ConcentratorParam",
+    "Baidu",
+    "BaiduParam",
+    "DuckDuckGo",
+    "DuckDuckGoParam",
+    "Wikipedia",
+    "WikipediaParam",
+    "PubMed",
+    "PubMedParam",
+    "ArXiv",
+    "ArXivParam",
+    "Google",
+    "GoogleParam",
+    "Bing",
+    "BingParam",
+    "GoogleScholar",
+    "GoogleScholarParam",
+    "DeepL",
+    "DeepLParam",
+    "GitHub",
+    "GitHubParam",
+    "BaiduFanyi",
+    "BaiduFanyiParam",
+    "QWeather",
+    "QWeatherParam",
+    "ExeSQL",
+    "ExeSQLParam",
+    "YahooFinance",
+    "YahooFinanceParam",
+    "WenCai",
+    "WenCaiParam",
+    "Jin10",
+    "Jin10Param",
+    "TuShare",
+    "TuShareParam",
+    "AkShare",
+    "AkShareParam",
+    "Crawler",
+    "CrawlerParam",
+    "Invoke",
+    "InvokeParam",
+    "Template",
+    "TemplateParam",
+    "Email",
+    "EmailParam",
+    "component_class"
+]

View File

@@ -428,7 +428,8 @@ class ComponentBase(ABC):
 def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
 o = getattr(self._param, self._param.output_var_name)
 if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
-if not isinstance(o, list): o = [o]
+if not isinstance(o, list):
+    o = [o]
 o = pd.DataFrame(o)
 if allow_partial or not isinstance(o, partial):
@@ -440,7 +441,8 @@ class ComponentBase(ABC):
 for oo in o():
 if not isinstance(oo, pd.DataFrame):
 outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
-else: outs = oo
+else:
+    outs = oo
 return self._param.output_var_name, outs
 def reset(self):
@@ -482,13 +484,15 @@ class ComponentBase(ABC):
 outs.append(pd.DataFrame([{"content": q["value"]}]))
 if outs:
 df = pd.concat(outs, ignore_index=True)
-if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
+if "content" in df:
+    df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
 return df
 upstream_outs = []
 for u in reversed_cpnts[::-1]:
-if self.get_component_name(u) in ["switch", "concentrator"]: continue
+if self.get_component_name(u) in ["switch", "concentrator"]:
+    continue
 if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
 o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
 if o is not None:
@@ -532,7 +536,8 @@ class ComponentBase(ABC):
 reversed_cpnts.extend(self._canvas.path[-1])
 for u in reversed_cpnts[::-1]:
-if self.get_component_name(u) in ["switch", "answer"]: continue
+if self.get_component_name(u) in ["switch", "answer"]:
+    continue
 return self._canvas.get_component(u)["obj"].output()[1]
 @staticmethod

View File

@@ -34,15 +34,18 @@ class CategorizeParam(GenerateParam):
 super().check()
 self.check_empty(self.category_description, "[Categorize] Category examples")
 for k, v in self.category_description.items():
-if not k: raise ValueError("[Categorize] Category name can not be empty!")
-if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
+if not k:
+    raise ValueError("[Categorize] Category name can not be empty!")
+if not v.get("to"):
+    raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
 def get_prompt(self):
 cate_lines = []
 for c, desc in self.category_description.items():
-for l in desc.get("examples", "").split("\n"):
-if not l: continue
-cate_lines.append("Question: {}\tCategory: {}".format(l, c))
+for line in desc.get("examples", "").split("\n"):
+    if not line:
+        continue
+    cate_lines.append("Question: {}\tCategory: {}".format(line, c))
 descriptions = []
 for c, desc in self.category_description.items():
 if desc.get("description"):

View File

@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 from abc import ABC
-import re
 from agent.component.base import ComponentBase, ComponentParamBase
 import deepl

View File

@@ -46,8 +46,10 @@ class ExeSQLParam(ComponentParamBase):
 self.check_empty(self.password, "Database password")
 self.check_positive_integer(self.top_n, "Number of records")
 if self.database == "rag_flow":
-if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
-if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
+if self.host == "ragflow-mysql":
+    raise ValueError("The host is not accessible.")
+if self.password == "infini_rag_flow":
+    raise ValueError("The host is not accessible.")
 class ExeSQL(ComponentBase, ABC):

View File

@@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase):
 def gen_conf(self):
 conf = {}
-if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
-if self.temperature > 0: conf["temperature"] = self.temperature
-if self.top_p > 0: conf["top_p"] = self.top_p
-if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
-if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
+if self.max_tokens > 0:
+    conf["max_tokens"] = self.max_tokens
+if self.temperature > 0:
+    conf["temperature"] = self.temperature
+if self.top_p > 0:
+    conf["top_p"] = self.top_p
+if self.presence_penalty > 0:
+    conf["presence_penalty"] = self.presence_penalty
+if self.frequency_penalty > 0:
+    conf["frequency_penalty"] = self.frequency_penalty
 return conf
@@ -83,7 +88,8 @@ class Generate(ComponentBase):
 recall_docs = []
 for i in idx:
 did = retrieval_res.loc[int(i), "doc_id"]
-if did in doc_ids: continue
+if did in doc_ids:
+    continue
 doc_ids.add(did)
 recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
@@ -108,7 +114,8 @@ class Generate(ComponentBase):
 retrieval_res = []
 self._param.inputs = []
 for para in self._param.parameters:
-if not para.get("component_id"): continue
+if not para.get("component_id"):
+    continue
 component_id = para["component_id"].split("@")[0]
 if para["component_id"].lower().find("@") >= 0:
 cpn_id, key = para["component_id"].split("@")
@@ -142,7 +149,8 @@ class Generate(ComponentBase):
 if retrieval_res:
 retrieval_res = pd.concat(retrieval_res, ignore_index=True)
-else: retrieval_res = pd.DataFrame([])
+else:
+    retrieval_res = pd.DataFrame([])
 for n, v in kwargs.items():
 prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
@@ -164,9 +172,11 @@ class Generate(ComponentBase):
 return pd.DataFrame([res])
 msg = self._canvas.get_history(self._param.message_history_window_size)
-if len(msg) < 1: msg.append({"role": "user", "content": ""})
+if len(msg) < 1:
+    msg.append({"role": "user", "content": ""})
 _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
-if len(msg) < 2: msg.append({"role": "user", "content": ""})
+if len(msg) < 2:
+    msg.append({"role": "user", "content": ""})
 ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
 if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
@@ -185,9 +195,11 @@ class Generate(ComponentBase):
 return
 msg = self._canvas.get_history(self._param.message_history_window_size)
-if len(msg) < 1: msg.append({"role": "user", "content": ""})
+if len(msg) < 1:
+    msg.append({"role": "user", "content": ""})
 _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
-if len(msg) < 2: msg.append({"role": "user", "content": ""})
+if len(msg) < 2:
+    msg.append({"role": "user", "content": ""})
 answer = ""
 for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
 res = {"content": ans, "reference": []}

View File

@@ -95,7 +95,8 @@ class RewriteQuestion(Generate, ABC):
 hist = self._canvas.get_history(4)
 conv = []
 for m in hist:
-if m["role"] not in ["user", "assistant"]: continue
+if m["role"] not in ["user", "assistant"]:
+    continue
 conv.append("{}: {}".format(m["role"].upper(), m["content"]))
 conv = "\n".join(conv)

View File

@@ -41,7 +41,8 @@ class SwitchParam(ComponentParamBase):
 def check(self):
 self.check_empty(self.conditions, "[Switch] conditions")
 for cond in self.conditions:
-if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
+if not cond["to"]:
+    raise ValueError("[Switch] 'To' can not be empty!")
 class Switch(ComponentBase, ABC):
@@ -51,7 +52,8 @@ class Switch(ComponentBase, ABC):
 res = []
 for cond in self._param.conditions:
 for item in cond["items"]:
-if not item["cpn_id"]: continue
+if not item["cpn_id"]:
+    continue
 if item["cpn_id"].find("begin") >= 0:
 continue
 cid = item["cpn_id"].split("@")[0]
@@ -63,7 +65,8 @@ class Switch(ComponentBase, ABC):
 for cond in self._param.conditions:
 res = []
 for item in cond["items"]:
-if not item["cpn_id"]:continue
+if not item["cpn_id"]:
+    continue
 cid = item["cpn_id"].split("@")[0]
 if item["cpn_id"].find("@") > 0:
 cpn_id, key = item["cpn_id"].split("@")
@@ -107,22 +110,22 @@ class Switch(ComponentBase, ABC):
 elif operator == ">":
 try:
 return True if float(input) > float(value) else False
-except Exception as e:
+except Exception:
 return True if input > value else False
 elif operator == "<":
 try:
 return True if float(input) < float(value) else False
-except Exception as e:
+except Exception:
 return True if input < value else False
 elif operator == "≥":
 try:
 return True if float(input) >= float(value) else False
-except Exception as e:
+except Exception:
 return True if input >= value else False
 elif operator == "≤":
 try:
 return True if float(input) <= float(value) else False
-except Exception as e:
+except Exception:
 return True if input <= value else False
 raise ValueError('Not supported operator' + operator)

View File

@@ -47,7 +47,8 @@ class Template(ComponentBase):
 self._param.inputs = []
 for para in self._param.parameters:
-if not para.get("component_id"): continue
+if not para.get("component_id"):
+    continue
 component_id = para["component_id"].split("@")[0]
 if para["component_id"].lower().find("@") >= 0:
 cpn_id, key = para["component_id"].split("@")

View File

@@ -43,6 +43,7 @@ if __name__ == '__main__':
 else:
 print(ans["content"])
-if DEBUG: print(canvas.path)
+if DEBUG:
+    print(canvas.path)
 question = input("\n==================== User =====================\n> ")
 canvas.add_user_input(question)

View File

@@ -142,7 +142,6 @@ def set_conversation():
 if not objs:
 return get_json_result(
 data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
-req = request.json
 try:
 if objs[0].source == "agent":
 e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
@@ -188,7 +187,8 @@ def completion():
 e, conv = API4ConversationService.get_by_id(req["conversation_id"])
 if not e:
 return get_data_error_result(message="Conversation not found!")
-if "quote" not in req: req["quote"] = False
+if "quote" not in req:
+    req["quote"] = False
 msg = []
 for m in req["messages"]:
@@ -197,7 +197,8 @@ def completion():
 if m["role"] == "assistant" and not msg:
 continue
 msg.append(m)
-if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
+if not msg[-1].get("id"):
+    msg[-1]["id"] = get_uuid()
 message_id = msg[-1]["id"]
 def fillin_conv(ans):
@@ -674,11 +675,13 @@ def completion_faq():
 e, conv = API4ConversationService.get_by_id(req["conversation_id"])
 if not e:
 return get_data_error_result(message="Conversation not found!")
-if "quote" not in req: req["quote"] = True
+if "quote" not in req:
+    req["quote"] = True
 msg = []
 msg.append({"role": "user", "content": req["word"]})
-if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
+if not msg[-1].get("id"):
+    msg[-1]["id"] = get_uuid()
 message_id = msg[-1]["id"]
 def fillin_conv(ans):

View File

@@ -13,10 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import json
 import traceback
-from functools import partial
 from flask import request, Response
 from flask_login import login_required, current_user
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
@@ -60,7 +58,8 @@ def rm():
 def save():
 req = request.json
 req["user_id"] = current_user.id
-if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
+if not isinstance(req["dsl"], str):
+    req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
 req["dsl"] = json.loads(req["dsl"])
 if "id" not in req:
@@ -153,7 +152,8 @@ def run():
 return resp
 for answer in canvas.run(stream=False):
-if answer.get("running_status"): continue
+if answer.get("running_status"):
+    continue
 final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
 canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
 if final_ans.get("reference"):

View File

@@ -237,7 +237,8 @@ def create():
 e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
 if not e:
 return get_data_error_result(message="Knowledgebase not found!")
-if kb.pagerank: d["pagerank_fea"] = kb.pagerank
+if kb.pagerank:
+    d["pagerank_fea"] = kb.pagerank
 embd_id = DocumentService.get_embd_id(req["doc_id"])
 embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)

View File

@@ -281,10 +281,12 @@ def thumbup():
 if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
 if up_down:
 msg["thumbup"] = True
-if "feedback" in msg: del msg["feedback"]
+if "feedback" in msg:
+    del msg["feedback"]
 else:
 msg["thumbup"] = False
-if feedback: msg["feedback"] = feedback
+if feedback:
+    msg["feedback"] = feedback
 break
 ConversationService.update_by_id(conv["id"], conv)

View File

@@ -37,10 +37,12 @@ def set_dialog():
 top_n = req.get("top_n", 6)
 top_k = req.get("top_k", 1024)
 rerank_id = req.get("rerank_id", "")
-if not rerank_id: req["rerank_id"] = ""
+if not rerank_id:
+    req["rerank_id"] = ""
 similarity_threshold = req.get("similarity_threshold", 0.1)
 vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
-if vector_similarity_weight is None: vector_similarity_weight = 0.3
+if vector_similarity_weight is None:
+    vector_similarity_weight = 0.3
 llm_setting = req.get("llm_setting", {})
 default_prompt = {
 "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-import json
 import os.path
 import pathlib
 import re
@@ -90,7 +89,8 @@ def web_crawl():
 raise LookupError("Can't find this knowledgebase!")
 blob = html2pdf(url)
-if not blob: return server_error_response(ValueError("Download failure."))
+if not blob:
+    return server_error_response(ValueError("Download failure."))
 root_folder = FileService.get_root_folder(current_user.id)
 pf_id = root_folder["id"]
@@ -290,7 +290,8 @@ def change_status():
 def rm():
 req = request.json
 doc_ids = req["doc_id"]
-if isinstance(doc_ids, str): doc_ids = [doc_ids]
+if isinstance(doc_ids, str):
+    doc_ids = [doc_ids]
 for doc_id in doc_ids:
 if not DocumentService.accessible4deletion(doc_id, current_user.id):

View File

@@ -351,8 +351,10 @@ def list_app():
 llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
 for o in objs:
-if not o.api_key: continue
-if o.llm_name + "@" + o.llm_factory in llm_set: continue
+if not o.api_key:
+    continue
+if o.llm_name + "@" + o.llm_factory in llm_set:
+    continue
 llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
 res = {}

View File

@@ -14,7 +14,7 @@
 # limitations under the License.
 #
-from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
+from api.db.services.canvas_service import UserCanvasService
 from api.utils.api_utils import get_error_data_result, token_required
 from api.utils.api_utils import get_result
 from flask import request

View File

@@ -41,7 +41,6 @@ from api.utils.api_utils import construct_json_result, get_parser_config
 from rag.nlp import search
 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL
-import os
 MAXIMUM_OF_UPLOADING_FILES = 256
@@ -976,12 +975,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
 if not req.get("content"):
 return get_error_data_result(message="`content` is required")
 if "important_keywords" in req:
-if type(req["important_keywords"]) != list:
+if not isinstance(req["important_keywords"], list):
 return get_error_data_result(
 "`important_keywords` is required to be a list"
 )
 if "questions" in req:
-if type(req["questions"]) != list:
+if not isinstance(req["questions"], list):
 return get_error_data_result(
 "`questions` is required to be a list"
 )

View File

@@ -143,8 +143,10 @@ def completion(tenant_id, chat_id):
 }
 conv.message.append(question)
 for m in conv.message:
-if m["role"] == "system": continue
-if m["role"] == "assistant" and not msg: continue
+if m["role"] == "system":
+    continue
+if m["role"] == "assistant" and not msg:
+    continue
 msg.append(m)
 message_id = msg[-1].get("id")
 e, dia = DialogService.get_by_id(conv.dialog_id)
@@ -267,7 +269,8 @@ def agent_completion(tenant_id, agent_id):
 if m["role"] == "assistant" and not msg:
 continue
 msg.append(m)
-if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
+if not msg[-1].get("id"):
+    msg[-1]["id"] = get_uuid()
 message_id = msg[-1]["id"]
 stream = req.get("stream", True)
@@ -361,7 +364,8 @@ def agent_completion(tenant_id, agent_id):
 return resp
 for answer in canvas.run(stream=False):
-if answer.get("running_status"): continue
+if answer.get("running_status"):
+    continue
 final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
 canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
 if final_ans.get("reference"):

View File

@@ -330,7 +330,7 @@ def user_info_from_github(access_token):
 headers=headers,
 ).json()
 user_info["email"] = next(
-(email for email in email_info if email["primary"] == True), None
+(email for email in email_info if email["primary"]), None
 )["email"]
 return user_info

View File

@@ -130,7 +130,7 @@ def is_continuous_field(cls: typing.Type) -> bool:
 for p in cls.__bases__:
 if p in CONTINUOUS_FIELD_TYPE:
 return True
-elif p != Field and p != object:
+elif p is not Field and p is not object:
 if is_continuous_field(p):
 return True
 else:

View File

@@ -170,7 +170,7 @@ def add_graph_templates():
 cnvs = json.load(open(os.path.join(dir, fnm), "r"))
 try:
 CanvasTemplateService.save(**cnvs)
-except:
+except Exception:
 CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
 except Exception:
 logging.exception("Add graph templates error: ")

View File

@@ -15,13 +15,14 @@
 #
 import pathlib
 import re
-from .user_service import UserService
+from .user_service import UserService as UserService
 def duplicate_name(query_func, **kwargs):
 fnm = kwargs["name"]
 objs = query_func(**kwargs)
-if not objs: return fnm
+if not objs:
+    return fnm
 ext = pathlib.Path(fnm).suffix #.jpg
 nm = re.sub(r"%s$"%ext, "", fnm)
 r = re.search(r"\(([0-9]+)\)$", nm)
@@ -31,8 +32,8 @@ def duplicate_name(query_func, **kwargs):
 nm = re.sub(r"\([0-9]+\)$", "", nm)
 c += 1
 nm = f"{nm}({c})"
-if ext: nm += f"{ext}"
+if ext:
+    nm += f"{ext}"
 kwargs["name"] = nm
 return duplicate_name(query_func, **kwargs)

View File

@@ -64,7 +64,8 @@ class API4ConversationService(CommonService):
 @classmethod
 @DB.connection_context()
 def stats(cls, tenant_id, from_date, to_date, source=None):
-if len(to_date) == 10: to_date += " 23:59:59"
+if len(to_date) == 10:
+    to_date += " 23:59:59"
 return cls.model.select(
 cls.model.create_date.truncate("day").alias("dt"),
 peewee.fn.COUNT(

View File

@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from datetime import datetime
-import peewee
-from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
+from api.db.db_models import DB, CanvasTemplate, UserCanvas
 from api.db.services.common_service import CommonService

View File

@@ -115,7 +115,7 @@ class CommonService:
 try:
 obj = cls.model.query(id=pid)[0]
 return True, obj
-except Exception as e:
+except Exception:
 return False, None
 @classmethod

View File

@@ -106,15 +106,15 @@ def message_fit_in(msg, max_length=4000):
 return c, msg
 ll = num_tokens_from_string(msg_[0]["content"])
-l = num_tokens_from_string(msg_[-1]["content"])
-if ll / (ll + l) > 0.8:
+ll2 = num_tokens_from_string(msg_[-1]["content"])
+if ll / (ll + ll2) > 0.8:
 m = msg_[0]["content"]
-m = encoder.decode(encoder.encode(m)[:max_length - l])
+m = encoder.decode(encoder.encode(m)[:max_length - ll2])
 msg[0]["content"] = m
 return max_length, msg
 m = msg_[1]["content"]
-m = encoder.decode(encoder.encode(m)[:max_length - l])
+m = encoder.decode(encoder.encode(m)[:max_length - ll2])
 msg[1]["content"] = m
 return max_length, msg
@@ -257,7 +257,8 @@ def chat(dialog, messages, stream=True, **kwargs):
 idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
 recall_docs = [
 d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
-if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+if not recall_docs:
+    recall_docs = kbinfos["doc_aggs"]
 kbinfos["doc_aggs"] = recall_docs
 refs = deepcopy(kbinfos)
@@ -433,13 +434,15 @@ def relevant(tenant_id, llm_id, question, contents: list):
 Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
 No other words needed except 'yes' or 'no'.
 """
-if not contents:return False
+if not contents:
+    return False
 contents = "Documents: \n" + " - ".join(contents)
 contents = f"Question: {question}\n" + contents
 if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
 contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
 ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
-if ans.lower().find("yes") >= 0: return True
+if ans.lower().find("yes") >= 0:
+    return True
 return False
@@ -481,8 +484,10 @@ Requirements:
 ]
 _, msg = message_fit_in(msg, chat_mdl.max_length)
 kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
-if isinstance(kwd, tuple): kwd = kwd[0]
-if kwd.find("**ERROR**") >=0: return ""
+if isinstance(kwd, tuple):
+    kwd = kwd[0]
+if kwd.find("**ERROR**") >=0:
+    return ""
 return kwd
@@ -508,8 +513,10 @@ Requirements:
 ]
 _, msg = message_fit_in(msg, chat_mdl.max_length)
 kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
-if isinstance(kwd, tuple): kwd = kwd[0]
-if kwd.find("**ERROR**") >= 0: return ""
+if isinstance(kwd, tuple):
+    kwd = kwd[0]
+if kwd.find("**ERROR**") >= 0:
+    return ""
 return kwd
@@ -520,7 +527,8 @@ def full_question(tenant_id, llm_id, messages):
 chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
 conv = []
 for m in messages:
-if m["role"] not in ["user", "assistant"]: continue
+if m["role"] not in ["user", "assistant"]:
+    continue
 conv.append("{}: {}".format(m["role"].upper(), m["content"]))
 conv = "\n".join(conv)
 today = datetime.date.today().isoformat()
@@ -581,7 +589,8 @@ Output: What's the weather in Rochester on {tomorrow}?
 def tts(tts_mdl, text):
-if not tts_mdl or not text: return
+if not tts_mdl or not text:
+    return
 bin = b""
 for chunk in tts_mdl.tts(text):
 bin += chunk
@@ -641,7 +650,8 @@ def ask(question, kb_ids, tenant_id):
 idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
 recall_docs = [
 d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
-if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+if not recall_docs:
+    recall_docs = kbinfos["doc_aggs"]
 kbinfos["doc_aggs"] = recall_docs
 refs = deepcopy(kbinfos)
 for c in refs["chunks"]:

View File

@@ -532,7 +532,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
 try:
 mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
 ensure_ascii=False, indent=2)
-if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
+if len(mind_map) < 32:
+    raise Exception("Few content: " + mind_map)
 cks.append({
 "id": get_uuid(),
 "doc_id": doc_id,

View File

@@ -20,7 +20,7 @@ from api.db.db_models import DB
 from api.db.db_models import File, File2Document
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
-from api.utils import current_timestamp, datetime_format, get_uuid
+from api.utils import current_timestamp, datetime_format
 class File2DocumentService(CommonService):
@@ -63,7 +63,7 @@ class File2DocumentService(CommonService):
 def update_by_file_id(cls, file_id, obj):
 obj["update_time"] = current_timestamp()
 obj["update_date"] = datetime_format(datetime.now())
-num = cls.model.update(obj).where(cls.model.id == file_id).execute()
+# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
 e, obj = cls.get_by_id(cls.model.id)
 return obj

View File

@@ -85,7 +85,8 @@ class FileService(CommonService):
 .join(Document, on=(File2Document.document_id == Document.id))
 .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
 .where(cls.model.id == file_id))
-if not kbs: return []
+if not kbs:
+    return []
 kbs_info_list = []
 for kb in list(kbs.dicts()):
 kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
@@ -304,7 +305,8 @@ class FileService(CommonService):
 @classmethod
 @DB.connection_context()
 def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
-for _ in File2DocumentService.get_by_document_id(doc["id"]): return
+for _ in File2DocumentService.get_by_document_id(doc["id"]):
+    return
 file = {
 "id": get_uuid(),
 "parent_id": kb_folder_id,

View File

@@ -107,7 +107,8 @@ class TenantLLMService(CommonService):
 model_config = cls.get_api_key(tenant_id, mdlnm)
 mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
-if model_config: model_config = model_config.to_dict()
+if model_config:
+    model_config = model_config.to_dict()
 if not model_config:
 if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
 llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)

View File

@@ -57,28 +57,33 @@ class TaskService(CommonService):
 Tenant.img2txt_id,
 Tenant.asr_id,
 Tenant.llm_id,
-cls.model.update_time]
-docs = cls.model.select(*fields) \
-.join(Document, on=(cls.model.doc_id == Document.id)) \
-.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
-.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
+cls.model.update_time,
+]
+docs = (
+    cls.model.select(*fields)
+    .join(Document, on=(cls.model.doc_id == Document.id))
+    .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
+    .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
 .where(cls.model.id == task_id)
+)
 docs = list(docs.dicts())
-if not docs: return None
+if not docs:
+    return None
 msg = "\nTask has been received."
-prog = random.random() / 10.
+prog = random.random() / 10.0
 if docs[0]["retry_count"] >= 3:
 msg = "\nERROR: Task is abandoned after 3 times attempts."
 prog = -1
-cls.model.update(progress_msg=cls.model.progress_msg + msg,
-progress=prog,
-retry_count=docs[0]["retry_count"]+1
-).where(
-cls.model.id == docs[0]["id"]).execute()
-if docs[0]["retry_count"] >= 3: return None
+cls.model.update(
+    progress_msg=cls.model.progress_msg + msg,
+    progress=prog,
+    retry_count=docs[0]["retry_count"] + 1,
+).where(cls.model.id == docs[0]["id"]).execute()
+if docs[0]["retry_count"] >= 3:
+    return None
 return docs[0]
@@ -86,21 +91,44 @@ class TaskService(CommonService):
 @classmethod
 @DB.connection_context()
 def get_ongoing_doc_name(cls):
 with DB.lock("get_task", -1):
-docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
-.join(Document, on=(cls.model.doc_id == Document.id)) \
-.join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
-.join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
+docs = (
+    cls.model.select(
+        *[Document.id, Document.kb_id, Document.location, File.parent_id]
+    )
+    .join(Document, on=(cls.model.doc_id == Document.id))
+    .join(
+        File2Document,
+        on=(File2Document.document_id == Document.id),
+        join_type=JOIN.LEFT_OUTER,
+    )
+    .join(
+        File,
+        on=(File2Document.file_id == File.id),
+        join_type=JOIN.LEFT_OUTER,
+    )
 .where(
 Document.status == StatusEnum.VALID.value,
 Document.run == TaskStatus.RUNNING.value,
 ~(Document.type == FileType.VIRTUAL.value),
 cls.model.progress < 1,
-cls.model.create_time >= current_timestamp() - 1000 * 600
+cls.model.create_time >= current_timestamp() - 1000 * 600,
 )
+)
 docs = list(docs.dicts())
-if not docs: return []
-return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
+if not docs:
+    return []
+return list(
+    set(
+        [
+            (
+                d["parent_id"] if d["parent_id"] else d["kb_id"],
+                d["location"],
+            )
+            for d in docs
+        ]
+    )
+)
 @classmethod
 @DB.connection_context()
@@ -118,28 +146,30 @@ class TaskService(CommonService):
 def update_progress(cls, id, info):
 if os.environ.get("MACOS"):
 if info["progress_msg"]:
-cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
-cls.model.id == id).execute()
+cls.model.update(
+    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
+).where(cls.model.id == id).execute()
 if "progress" in info:
 cls.model.update(progress=info["progress"]).where(
-cls.model.id == id).execute()
+cls.model.id == id
+).execute()
 return
 with DB.lock("update_progress", -1):
 if info["progress_msg"]:
-cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
-cls.model.id == id).execute()
+cls.model.update(
+    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
+).where(cls.model.id == id).execute()
 if "progress" in info:
 cls.model.update(progress=info["progress"]).where(
-cls.model.id == id).execute()
+cls.model.id == id
+).execute()
 def queue_tasks(doc: dict, bucket: str, name: str):
 def new_task():
-return {
-"id": get_uuid(),
-"doc_id": doc["id"]
-}
+return {"id": get_uuid(), "doc_id": doc["id"]}
 tsks = []
 if doc["type"] == FileType.PDF.value:
@@ -150,8 +180,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
 if doc["parser_id"] == "paper":
 page_size = doc["parser_config"].get("task_page_size", 22)
 if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
-page_size = 10 ** 9
-page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
+page_size = 10**9
+page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
 for s, e in page_ranges:
 s -= 1
 s = max(0, s)
@@ -177,4 +207,6 @@ def queue_tasks(doc: dict, bucket: str, name: str):
 DocumentService.begin2parse(doc["id"])
 for t in tsks:
-assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
+assert REDIS_CONN.queue_product(
+    SVR_QUEUE_NAME, message=t
+), "Can't access Redis. Please check the Redis' status."

View File

@@ -22,7 +22,7 @@ from api.db import UserTenantRole
 from api.db.db_models import DB, UserTenant
 from api.db.db_models import User, Tenant
 from api.db.services.common_service import CommonService
-from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
+from api.utils import get_uuid, current_timestamp, datetime_format
 from api.db import StatusEnum

View File

@@ -21,10 +21,7 @@
 import logging
 import os
 from api.utils.log_utils import initRootLogger
-LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
-initRootLogger("ragflow_server", LOG_LEVELS)
-import os
 import signal
 import sys
 import time
@@ -44,6 +41,9 @@ from api.versions import get_ragflow_version
 from api.utils import show_configs
 from rag.settings import print_rag_settings
+LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
+initRootLogger("ragflow_server", LOG_LEVELS)
 def update_progress():
 while True:

View File

@@ -36,7 +36,6 @@ from werkzeug.http import HTTP_STATUS_CODES
 from api.db.db_models import APIToken
 from api import settings
-from api import settings
 from api.utils import CustomJSONEncoder, get_uuid
 from api.utils import json_dumps
 from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC

View File

@@ -45,5 +45,5 @@ try:
 pool = Pool(processes=1)
 thread = pool.apply_async(download_nltk_data)
 binary = thread.get(timeout=60)
-except Exception as e:
+except Exception:
 print('\x1b[6;37;41m WARNING \x1b[0m' + "Downloading NLTK data failure.", flush=True)

View File

@@ -19,3 +19,15 @@ from .html_parser import RAGFlowHtmlParser as HtmlParser
 from .json_parser import RAGFlowJsonParser as JsonParser
 from .markdown_parser import RAGFlowMarkdownParser as MarkdownParser
 from .txt_parser import RAGFlowTxtParser as TxtParser
+__all__ = [
+    "PdfParser",
+    "PlainParser",
+    "DocxParser",
+    "ExcelParser",
+    "PptParser",
+    "HtmlParser",
+    "JsonParser",
+    "MarkdownParser",
+    "TxtParser",
+]

View File

@@ -29,7 +29,8 @@ class RAGFlowExcelParser:
 for sheetname in wb.sheetnames:
 ws = wb[sheetname]
 rows = list(ws.rows)
-if not rows: continue
+if not rows:
+    continue
 tb_rows_0 = "<tr>"
 for t in list(rows[0]):
@@ -40,7 +41,9 @@ class RAGFlowExcelParser:
 tb = ""
 tb += f"<table><caption>{sheetname}</caption>"
 tb += tb_rows_0
-for r in list(rows[1 + chunk_i * chunk_rows:1 + (chunk_i + 1) * chunk_rows]):
+for r in list(
+    rows[1 + chunk_i * chunk_rows : 1 + (chunk_i + 1) * chunk_rows]
+):
 tb += "<tr>"
 for i, c in enumerate(r):
 if c.value is None:
@@ -62,20 +65,21 @@ class RAGFlowExcelParser:
 for sheetname in wb.sheetnames:
 ws = wb[sheetname]
 rows = list(ws.rows)
-if not rows:continue
+if not rows:
+    continue
 ti = list(rows[0])
 for r in list(rows[1:]):
-l = []
+fields = []
 for i, c in enumerate(r):
 if not c.value:
 continue
 t = str(ti[i].value) if i < len(ti) else ""
 t += ("：" if t else "") + str(c.value)
-l.append(t)
-l = "; ".join(l)
+fields.append(t)
+line = "; ".join(fields)
 if sheetname.lower().find("sheet") < 0:
-l += " ——" + sheetname
+line += " ——" + sheetname
-res.append(l)
+res.append(line)
 return res
 @staticmethod

View File

@@ -36,7 +36,7 @@ class RAGFlowHtmlParser:
 @classmethod
 def parser_txt(cls, txt):
-if type(txt) != str:
+if not isinstance(txt, str):
 raise TypeError("txt type should be str!")
 html_doc = readability.Document(txt)
 title = html_doc.title()

View File

@@ -22,7 +22,7 @@ class RAGFlowJsonParser:
 txt = binary.decode(encoding, errors="ignore")
 json_data = json.loads(txt)
 chunks = self.split_json(json_data, True)
-sections = [json.dumps(l, ensure_ascii=False) for l in chunks if l]
+sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
 return sections
 @staticmethod

View File

@@ -752,7 +752,7 @@ class RAGFlowPdfParser:
 "x1": np.max([b["x1"] for b in bxs]),
 "bottom": np.max([b["bottom"] for b in bxs]) - ht
 }
-louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
+louts = [layout for layout in self.page_layout[pn] if layout["type"] == ltype]
 ii = Recognizer.find_overlapped(b, louts, naive=True)
 if ii is not None:
 b = louts[ii]
@@ -763,7 +763,8 @@ class RAGFlowPdfParser:
 "layoutno", "")))
 left, top, right, bott = b["x0"], b["top"], b["x1"], b["bottom"]
-if right < left: right = left + 1
+if right < left:
+    right = left + 1
 poss.append((pn + self.page_from, left, right, top, bott))
 return self.page_images[pn] \
 .crop((left * ZM, top * ZM,
@@ -845,7 +846,8 @@ class RAGFlowPdfParser:
 top = bx["top"] - self.page_cum_height[pn[0] - 1]
 bott = bx["bottom"] - self.page_cum_height[pn[0] - 1]
 page_images_cnt = len(self.page_images)
-if pn[-1] - 1 >= page_images_cnt: return ""
+if pn[-1] - 1 >= page_images_cnt:
+    return ""
 while bott * ZM > self.page_images[pn[-1] - 1].size[1]:
 bott -= self.page_images[pn[-1] - 1].size[1] / ZM
 pn.append(pn[-1] + 1)
@@ -889,7 +891,6 @@ class RAGFlowPdfParser:
 nonlocal mh, pw, lines, widths
 lines.append(line)
 widths.append(width(line))
-width_mean = np.mean(widths)
 mmj = self.proj_match(
 line["text"]) or line.get(
 "layout_type",
@@ -994,7 +995,7 @@ class RAGFlowPdfParser:
 else:
 self.is_english = False
-st = timer()
+# st = timer()
 for i, img in enumerate(self.page_images_x2):
 chars = self.page_chars[i] if not self.is_english else []
 self.mean_height.append(
@@ -1028,8 +1029,8 @@ class RAGFlowPdfParser:
 self.page_cum_height = np.cumsum(self.page_cum_height)
 assert len(self.page_cum_height) == len(self.page_images) + 1
-if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
-page_to, callback)
+if len(self.boxes) == 0 and zoomin < 9:
+    self.__images__(fnm, zoomin * 3, page_from, page_to, callback)
 def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
 self.__images__(fnm, zoomin)
@@ -1168,7 +1169,7 @@ class PlainParser(object):
 if not self.outlines:
 logging.warning("Miss outlines")
-return [(l, "") for l in lines], []
+return [(line, "") for line in lines], []
 def crop(self, ck, need_position):
 raise NotImplementedError

View File

@ -15,21 +15,42 @@ import datetime
def refactor(cv): def refactor(cv):
for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]: for n in [
if n in cv and cv[n] is not None: del cv[n] "raw_txt",
"parser_name",
"inference",
"ori_text",
"use_time",
"time_stat",
]:
if n in cv and cv[n] is not None:
del cv[n]
cv["is_deleted"] = 0 cv["is_deleted"] = 0
if "basic" not in cv: cv["basic"] = {} if "basic" not in cv:
if cv["basic"].get("photo2"): del cv["basic"]["photo2"] cv["basic"] = {}
if cv["basic"].get("photo2"):
del cv["basic"]["photo2"]
for n in ["education", "work", "certificate", "project", "language", "skill", "training"]: for n in [
if n not in cv or cv[n] is None: continue "education",
if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()] "work",
if type(cv[n]) != type([]): "certificate",
"project",
"language",
"skill",
"training",
]:
if n not in cv or cv[n] is None:
continue
if isinstance(cv[n], dict):
cv[n] = [v for _, v in cv[n].items()]
if not isinstance(cv[n], list):
del cv[n] del cv[n]
continue continue
vv = [] vv = []
for v in cv[n]: for v in cv[n]:
if "external" in v and v["external"] is not None: del v["external"] if "external" in v and v["external"] is not None:
del v["external"]
vv.append(v) vv.append(v)
cv[n] = {str(i): vv[i] for i in range(len(vv))} cv[n] = {str(i): vv[i] for i in range(len(vv))}
@ -42,24 +63,44 @@ def refactor(cv):
cv["basic"][t] = cv["basic"][n] cv["basic"][t] = cv["basic"][n]
del cv["basic"][n] del cv["basic"][n]
work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", "")) work = sorted(
edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", "")) [v for _, v in cv.get("work", {}).items()],
key=lambda x: x.get("start_time", ""),
)
edu = sorted(
[v for _, v in cv.get("education", {}).items()],
key=lambda x: x.get("start_time", ""),
)
if work: if work:
cv["basic"]["work_start_time"] = work[0].get("start_time", "") cv["basic"]["work_start_time"] = work[0].get("start_time", "")
cv["basic"]["management_experience"] = 'Y' if any( cv["basic"]["management_experience"] = (
[w.get("management_experience", '') == 'Y' for w in work]) else 'N' "Y"
if any([w.get("management_experience", "") == "Y" for w in work])
else "N"
)
cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0") cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")
for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities", for n in [
"corporation_type", "scale", "corporation_name"]: "annual_salary_from",
"annual_salary_to",
"industry_name",
"position_name",
"responsibilities",
"corporation_type",
"scale",
"corporation_name",
]:
cv["basic"][n] = work[-1].get(n, "") cv["basic"][n] = work[-1].get(n, "")
if edu: if edu:
for n in ["school_name", "discipline_name"]: for n in ["school_name", "discipline_name"]:
if n in edu[-1]: cv["basic"][n] = edu[-1][n] if n in edu[-1]:
cv["basic"][n] = edu[-1][n]
cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if "contact" not in cv: cv["contact"] = {} if "contact" not in cv:
if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "") cv["contact"] = {}
if not cv["contact"].get("name"):
cv["contact"]["name"] = cv["basic"].get("name", "")
return cv return cv

View File

@ -21,13 +21,18 @@ from . import regions
current_file_path = os.path.dirname(os.path.abspath(__file__)) current_file_path = os.path.dirname(os.path.abspath(__file__))
GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0) GOODS = pd.read_csv(
os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0
).fillna(0)
GOODS["cid"] = GOODS["cid"].astype(str) GOODS["cid"] = GOODS["cid"].astype(str)
GOODS = GOODS.set_index(["cid"]) GOODS = GOODS.set_index(["cid"])
CORP_TKS = json.load(open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")) CORP_TKS = json.load(
open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")
)
GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r")) GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r"))
CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r")) CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r"))
def baike(cid, default_v=0): def baike(cid, default_v=0):
global GOODS global GOODS
try: try:
@ -39,27 +44,41 @@ def baike(cid, default_v=0):
def corpNorm(nm, add_region=True): def corpNorm(nm, add_region=True):
global CORP_TKS global CORP_TKS
    if not nm or type(nm)!=type(""):return ""     if not nm or not isinstance(nm, str):
return ""
nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower() nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower()
nm = re.sub(r"&amp;", "&", nm) nm = re.sub(r"&amp;", "&", nm)
nm = re.sub(r"[\(\)\+'\"\t \*\\【】-]+", " ", nm) nm = re.sub(r"[\(\)\+'\"\t \*\\【】-]+", " ", nm)
nm = re.sub(r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE) nm = re.sub(
nm = re.sub(r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", nm, 10000, re.IGNORECASE) r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE
if not nm or (len(nm)<5 and not regions.isName(nm[0:2])):return nm )
nm = re.sub(
r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$",
"",
nm,
10000,
re.IGNORECASE,
)
if not nm or (len(nm) < 5 and not regions.isName(nm[0:2])):
return nm
tks = rag_tokenizer.tokenize(nm).split() tks = rag_tokenizer.tokenize(nm).split()
reg = [t for i,t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)] reg = [t for i, t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
nm = "" nm = ""
for t in tks: for t in tks:
if regions.isName(t) or t in CORP_TKS:continue if regions.isName(t) or t in CORP_TKS:
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):nm += " " continue
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):
nm += " "
nm += t nm += t
r = re.search(r"^([^a-z0-9 \(\)&]{2,})[a-z ]{4,}$", nm.strip()) r = re.search(r"^([^a-z0-9 \(\)&]{2,})[a-z ]{4,}$", nm.strip())
if r:nm = r.group(1) if r:
nm = r.group(1)
r = re.search(r"^([a-z ]{3,})[^a-z0-9 \(\)&]{2,}$", nm.strip()) r = re.search(r"^([a-z ]{3,})[^a-z0-9 \(\)&]{2,}$", nm.strip())
if r:nm = r.group(1) if r:
return nm.strip() + (("" if not reg else "(%s)"%reg[0]) if add_region else "") nm = r.group(1)
return nm.strip() + (("" if not reg else "(%s)" % reg[0]) if add_region else "")
def rmNoise(n): def rmNoise(n):
@ -67,33 +86,40 @@ def rmNoise(n):
n = re.sub(r"[,. &()]+", "", n) n = re.sub(r"[,. &()]+", "", n)
return n return n
GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP]) GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP])
for c,v in CORP_TAG.items(): for c, v in CORP_TAG.items():
cc = corpNorm(rmNoise(c), False) cc = corpNorm(rmNoise(c), False)
if not cc: if not cc:
logging.debug(c) logging.debug(c)
CORP_TAG = {corpNorm(rmNoise(c), False):v for c,v in CORP_TAG.items()} CORP_TAG = {corpNorm(rmNoise(c), False): v for c, v in CORP_TAG.items()}
def is_good(nm): def is_good(nm):
global GOOD_CORP global GOOD_CORP
if nm.find("外派")>=0:return False if nm.find("外派") >= 0:
return False
nm = rmNoise(nm) nm = rmNoise(nm)
nm = corpNorm(nm, False) nm = corpNorm(nm, False)
for n in GOOD_CORP: for n in GOOD_CORP:
if re.match(r"[0-9a-zA-Z]+$", n): if re.match(r"[0-9a-zA-Z]+$", n):
if n == nm: return True if n == nm:
elif nm.find(n)>=0:return True return True
elif nm.find(n) >= 0:
return True
return False return False
def corp_tag(nm): def corp_tag(nm):
global CORP_TAG global CORP_TAG
nm = rmNoise(nm) nm = rmNoise(nm)
nm = corpNorm(nm, False) nm = corpNorm(nm, False)
for n in CORP_TAG.keys(): for n in CORP_TAG.keys():
if re.match(r"[0-9a-zA-Z., ]+$", n): if re.match(r"[0-9a-zA-Z., ]+$", n):
if n == nm: return CORP_TAG[n] if n == nm:
elif nm.find(n)>=0: return CORP_TAG[n]
if len(n)<3 and len(nm)/len(n)>=2:continue elif nm.find(n) >= 0:
if len(n) < 3 and len(nm) / len(n) >= 2:
continue
return CORP_TAG[n] return CORP_TAG[n]
return [] return []

View File

@ -11,27 +11,31 @@
# limitations under the License. # limitations under the License.
# #
TBL = {"94":"EMBA", TBL = {
"6":"MBA", "94": "EMBA",
"95":"MPA", "6": "MBA",
"92":"专升本", "95": "MPA",
"4":"专科", "92": "专升本",
"90":"中专", "4": "专科",
"91":"中技", "90": "中专",
"86":"初中", "91": "中技",
"3":"博士", "86": "初中",
"10":"博士后", "3": "博士",
"1":"本科", "10": "博士后",
"2":"硕士", "1": "本科",
"87":"职高", "2": "硕士",
"89":"高中" "87": "职高",
"89": "高中",
} }
TBL_ = {v:k for k,v in TBL.items()} TBL_ = {v: k for k, v in TBL.items()}
def get_name(id): def get_name(id):
return TBL.get(str(id), "") return TBL.get(str(id), "")
def get_id(nm): def get_id(nm):
if not nm:return "" if not nm:
return ""
return TBL_.get(nm.upper().strip(), "") return TBL_.get(nm.upper().strip(), "")

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -16,8 +16,11 @@ import json
import re import re
import copy import copy
import pandas as pd import pandas as pd
current_file_path = os.path.dirname(os.path.abspath(__file__)) current_file_path = os.path.dirname(os.path.abspath(__file__))
TBL = pd.read_csv(os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0).fillna("") TBL = pd.read_csv(
os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0
).fillna("")
TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip()) TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip())
GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r")) GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r"))
GOOD_SCH = set([re.sub(r"[,. &()]+", "", c) for c in GOOD_SCH]) GOOD_SCH = set([re.sub(r"[,. &()]+", "", c) for c in GOOD_SCH])
@ -26,14 +29,15 @@ GOOD_SCH = set([re.sub(r"[,. &()]+", "", c) for c in GOOD_SCH])
def loadRank(fnm): def loadRank(fnm):
global TBL global TBL
TBL["rank"] = 1000000 TBL["rank"] = 1000000
with open(fnm, "r", encoding='utf-8') as f: with open(fnm, "r", encoding="utf-8") as f:
while True: while True:
l = f.readline() line = f.readline()
if not l:break if not line:
l = l.strip("\n").split(",") break
line = line.strip("\n").split(",")
try: try:
nm,rk = l[0].strip(),int(l[1]) nm, rk = line[0].strip(), int(line[1])
#assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>" # assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
TBL.loc[((TBL.name_cn == nm) | (TBL.name_en == nm)), "rank"] = rk TBL.loc[((TBL.name_cn == nm) | (TBL.name_en == nm)), "rank"] = rk
except Exception: except Exception:
pass pass
@ -44,27 +48,35 @@ loadRank(os.path.join(current_file_path, "res/school.rank.csv"))
def split(txt): def split(txt):
tks = [] tks = []
for t in re.sub(r"[ \t]+", " ",txt).split(): for t in re.sub(r"[ \t]+", " ", txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \ if (
re.match(r"[a-zA-Z]", t) and tks: tks
and re.match(r".*[a-zA-Z]$", tks[-1])
and re.match(r"[a-zA-Z]", t)
and tks
):
tks[-1] = tks[-1] + " " + t tks[-1] = tks[-1] + " " + t
else:tks.append(t) else:
tks.append(t)
return tks return tks
def select(nm): def select(nm):
global TBL global TBL
if not nm:return if not nm:
if isinstance(nm, list):nm = str(nm[0]) return
if isinstance(nm, list):
nm = str(nm[0])
nm = split(nm)[0] nm = split(nm)[0]
nm = str(nm).lower().strip() nm = str(nm).lower().strip()
nm = re.sub(r"[(][^()]+[)]", "", nm.lower()) nm = re.sub(r"[(][^()]+[)]", "", nm.lower())
nm = re.sub(r"(^the |[,.&();;·]+|^(英国|美国|瑞士))", "", nm) nm = re.sub(r"(^the |[,.&();;·]+|^(英国|美国|瑞士))", "", nm)
nm = re.sub(r"大学.*学院", "大学", nm) nm = re.sub(r"大学.*学院", "大学", nm)
tbl = copy.deepcopy(TBL) tbl = copy.deepcopy(TBL)
tbl["hit_alias"] = tbl["alias"].map(lambda x:nm in set(x.split("+"))) tbl["hit_alias"] = tbl["alias"].map(lambda x: nm in set(x.split("+")))
res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | (tbl.hit_alias == True))] res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | tbl.hit_alias)]
if res.empty:return if res.empty:
return
return json.loads(res.to_json(orient="records"))[0] return json.loads(res.to_json(orient="records"))[0]
@ -74,4 +86,3 @@ def is_good(nm):
nm = re.sub(r"[(][^()]+[)]", "", nm.lower()) nm = re.sub(r"[(][^()]+[)]", "", nm.lower())
nm = re.sub(r"[''`‘’“”,. &();]+", "", nm) nm = re.sub(r"[''`‘’“”,. &();]+", "", nm)
return nm in GOOD_SCH return nm in GOOD_SCH

View File

@ -25,7 +25,8 @@ from xpinyin import Pinyin
from contextlib import contextmanager from contextlib import contextmanager
class TimeoutException(Exception): pass class TimeoutException(Exception):
pass
@contextmanager @contextmanager
@ -50,8 +51,10 @@ def rmHtmlTag(line):
def highest_degree(dg): def highest_degree(dg):
if not dg: return "" if not dg:
if type(dg) == type(""): dg = [dg] return ""
if isinstance(dg, str):
dg = [dg]
m = {"初中": 0, "高中": 1, "中专": 2, "大专": 3, "专升本": 4, "本科": 5, "硕士": 6, "博士": 7, "博士后": 8} m = {"初中": 0, "高中": 1, "中专": 2, "大专": 3, "专升本": 4, "本科": 5, "硕士": 6, "博士": 7, "博士后": 8}
return sorted([(d, m.get(d, -1)) for d in dg], key=lambda x: x[1] * -1)[0][0] return sorted([(d, m.get(d, -1)) for d in dg], key=lambda x: x[1] * -1)[0][0]
@ -68,10 +71,12 @@ def forEdu(cv):
for ii, n in enumerate(sorted(cv["education_obj"], key=lambda x: x.get("start_time", "3"))): for ii, n in enumerate(sorted(cv["education_obj"], key=lambda x: x.get("start_time", "3"))):
e = {} e = {}
if n.get("end_time"): if n.get("end_time"):
if n["end_time"] > edu_end_dt: edu_end_dt = n["end_time"] if n["end_time"] > edu_end_dt:
edu_end_dt = n["end_time"]
try: try:
dt = n["end_time"] dt = n["end_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt) if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt) y, m, d = getYMD(dt)
ed_dt.append(str(y)) ed_dt.append(str(y))
e["end_dt_kwd"] = str(y) e["end_dt_kwd"] = str(y)
@ -80,7 +85,8 @@ def forEdu(cv):
if n.get("start_time"): if n.get("start_time"):
try: try:
dt = n["start_time"] dt = n["start_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt) if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt) y, m, d = getYMD(dt)
st_dt.append(str(y)) st_dt.append(str(y))
e["start_dt_kwd"] = str(y) e["start_dt_kwd"] = str(y)
@ -89,13 +95,20 @@ def forEdu(cv):
r = schools.select(n.get("school_name", "")) r = schools.select(n.get("school_name", ""))
if r: if r:
if str(r.get("type", "")) == "1": fea.append("211") if str(r.get("type", "")) == "1":
if str(r.get("type", "")) == "2": fea.append("211") fea.append("211")
if str(r.get("is_abroad", "")) == "1": fea.append("留学") if str(r.get("type", "")) == "2":
if str(r.get("is_double_first", "")) == "1": fea.append("双一流") fea.append("211")
if str(r.get("is_985", "")) == "1": fea.append("985") if str(r.get("is_abroad", "")) == "1":
if str(r.get("is_world_known", "")) == "1": fea.append("海外知名") fea.append("留学")
if r.get("rank") and cv["school_rank_int"] > r["rank"]: cv["school_rank_int"] = r["rank"] if str(r.get("is_double_first", "")) == "1":
fea.append("双一流")
if str(r.get("is_985", "")) == "1":
fea.append("985")
if str(r.get("is_world_known", "")) == "1":
fea.append("海外知名")
if r.get("rank") and cv["school_rank_int"] > r["rank"]:
cv["school_rank_int"] = r["rank"]
if n.get("school_name") and isinstance(n["school_name"], str): if n.get("school_name") and isinstance(n["school_name"], str):
sch.append(re.sub(r"(211|985|重点大学|[,&;-])", "", n["school_name"])) sch.append(re.sub(r"(211|985|重点大学|[,&;-])", "", n["school_name"]))
@ -106,22 +119,25 @@ def forEdu(cv):
maj.append(n["discipline_name"]) maj.append(n["discipline_name"])
e["major_kwd"] = n["discipline_name"] e["major_kwd"] = n["discipline_name"]
if not n.get("degree") and "985" in fea and not first_fea: n["degree"] = "1" if not n.get("degree") and "985" in fea and not first_fea:
n["degree"] = "1"
if n.get("degree"): if n.get("degree"):
d = degrees.get_name(n["degree"]) d = degrees.get_name(n["degree"])
if d: e["degree_kwd"] = d if d:
if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", e["degree_kwd"] = d
n.get( if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name",""))):
"school_name", d = "专升本"
""))): d = "专升本" if d:
if d: deg.append(d) deg.append(d)
# for first degree # for first degree
if not fdeg and d in ["中专", "专升本", "专科", "本科", "大专"]: if not fdeg and d in ["中专", "专升本", "专科", "本科", "大专"]:
fdeg = [d] fdeg = [d]
if n.get("school_name"): fsch = [n["school_name"]] if n.get("school_name"):
if n.get("discipline_name"): fmaj = [n["discipline_name"]] fsch = [n["school_name"]]
if n.get("discipline_name"):
fmaj = [n["discipline_name"]]
first_fea = copy.deepcopy(fea) first_fea = copy.deepcopy(fea)
edu_nst.append(e) edu_nst.append(e)
@ -140,16 +156,26 @@ def forEdu(cv):
else: else:
cv["sch_rank_kwd"].append("一般学校") cv["sch_rank_kwd"].append("一般学校")
if edu_nst: cv["edu_nst"] = edu_nst if edu_nst:
if fea: cv["edu_fea_kwd"] = list(set(fea)) cv["edu_nst"] = edu_nst
if first_fea: cv["edu_first_fea_kwd"] = list(set(first_fea)) if fea:
if maj: cv["major_kwd"] = maj cv["edu_fea_kwd"] = list(set(fea))
if fsch: cv["first_school_name_kwd"] = fsch if first_fea:
if fdeg: cv["first_degree_kwd"] = fdeg cv["edu_first_fea_kwd"] = list(set(first_fea))
if fmaj: cv["first_major_kwd"] = fmaj if maj:
if st_dt: cv["edu_start_kwd"] = st_dt cv["major_kwd"] = maj
if ed_dt: cv["edu_end_kwd"] = ed_dt if fsch:
if ed_dt: cv["edu_end_int"] = max([int(t) for t in ed_dt]) cv["first_school_name_kwd"] = fsch
if fdeg:
cv["first_degree_kwd"] = fdeg
if fmaj:
cv["first_major_kwd"] = fmaj
if st_dt:
cv["edu_start_kwd"] = st_dt
if ed_dt:
cv["edu_end_kwd"] = ed_dt
if ed_dt:
cv["edu_end_int"] = max([int(t) for t in ed_dt])
if deg: if deg:
if "本科" in deg and "专科" in deg: if "本科" in deg and "专科" in deg:
deg.append("专升本") deg.append("专升本")
@ -158,8 +184,10 @@ def forEdu(cv):
cv["highest_degree_kwd"] = highest_degree(deg) cv["highest_degree_kwd"] = highest_degree(deg)
if edu_end_dt: if edu_end_dt:
try: try:
if re.match(r"[0-9]{9,}", edu_end_dt): edu_end_dt = turnTm2Dt(edu_end_dt) if re.match(r"[0-9]{9,}", edu_end_dt):
if edu_end_dt.strip("\n") == "至今": edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today())) edu_end_dt = turnTm2Dt(edu_end_dt)
if edu_end_dt.strip("\n") == "至今":
edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
y, m, d = getYMD(edu_end_dt) y, m, d = getYMD(edu_end_dt)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000)) cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e: except Exception as e:
@ -171,7 +199,8 @@ def forEdu(cv):
or not cv.get("degree_kwd"): or not cv.get("degree_kwd"):
for c in sch: for c in sch:
if schools.is_good(c): if schools.is_good(c):
if "tag_kwd" not in cv: cv["tag_kwd"] = [] if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好学校") cv["tag_kwd"].append("好学校")
cv["tag_kwd"].append("好学历") cv["tag_kwd"].append("好学历")
break break
@ -180,28 +209,39 @@ def forEdu(cv):
any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \ any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \
or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \ or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \
or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]): or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]):
if "tag_kwd" not in cv: cv["tag_kwd"] = [] if "tag_kwd" not in cv:
if "好学历" not in cv["tag_kwd"]: cv["tag_kwd"].append("好学历") cv["tag_kwd"] = []
if "好学历" not in cv["tag_kwd"]:
cv["tag_kwd"].append("好学历")
if cv.get("major_kwd"): cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj)) if cv.get("major_kwd"):
if cv.get("school_name_kwd"): cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch)) cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
if cv.get("first_school_name_kwd"): cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch)) if cv.get("school_name_kwd"):
if cv.get("first_major_kwd"): cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj)) cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
if cv.get("first_school_name_kwd"):
cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
if cv.get("first_major_kwd"):
cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))
return cv return cv
def forProj(cv): def forProj(cv):
if not cv.get("project_obj"): return cv if not cv.get("project_obj"):
return cv
pro_nms, desc = [], [] pro_nms, desc = [], []
for i, n in enumerate( for i, n in enumerate(
sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if type(x) == type({}) else "", sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "",
reverse=True)): reverse=True)):
if n.get("name"): pro_nms.append(n["name"]) if n.get("name"):
if n.get("describe"): desc.append(str(n["describe"])) pro_nms.append(n["name"])
if n.get("responsibilities"): desc.append(str(n["responsibilities"])) if n.get("describe"):
if n.get("achivement"): desc.append(str(n["achivement"])) desc.append(str(n["describe"]))
if n.get("responsibilities"):
desc.append(str(n["responsibilities"]))
if n.get("achivement"):
desc.append(str(n["achivement"]))
if pro_nms: if pro_nms:
# cv["pro_nms_tks"] = rag_tokenizer.tokenize(" ".join(pro_nms)) # cv["pro_nms_tks"] = rag_tokenizer.tokenize(" ".join(pro_nms))
@ -233,15 +273,16 @@ def forWork(cv):
work_st_tm = "" work_st_tm = ""
corp_tags = [] corp_tags = []
for i, n in enumerate( for i, n in enumerate(
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if type(x) == type({}) else "", sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "",
reverse=True)): reverse=True)):
if type(n) == type(""): if isinstance(n, str):
try: try:
n = json_loads(n) n = json_loads(n)
except Exception: except Exception:
continue continue
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"] if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm):
work_st_tm = n["start_time"]
for c in flds: for c in flds:
if not n.get(c) or str(n[c]) == '0': if not n.get(c) or str(n[c]) == '0':
fea[c].append("") fea[c].append("")
@ -262,14 +303,18 @@ def forWork(cv):
fea[c].append(rmHtmlTag(str(n[c]).lower())) fea[c].append(rmHtmlTag(str(n[c]).lower()))
y, m, d = getYMD(n.get("start_time")) y, m, d = getYMD(n.get("start_time"))
if not y or not m: continue if not y or not m:
continue
st = "%s-%02d-%02d" % (y, int(m), int(d)) st = "%s-%02d-%02d" % (y, int(m), int(d))
latest_job_tm = st latest_job_tm = st
y, m, d = getYMD(n.get("end_time")) y, m, d = getYMD(n.get("end_time"))
if (not y or not m) and i > 0: continue if (not y or not m) and i > 0:
if not y or not m or int(y) > 2022: y, m, d = getYMD(str(n.get("updated_at", ""))) continue
if not y or not m: continue if not y or not m or int(y) > 2022:
y, m, d = getYMD(str(n.get("updated_at", "")))
if not y or not m:
continue
ed = "%s-%02d-%02d" % (y, int(m), int(d)) ed = "%s-%02d-%02d" % (y, int(m), int(d))
try: try:
@ -279,22 +324,28 @@ def forWork(cv):
if n.get("scale"): if n.get("scale"):
r = re.search(r"^([0-9]+)", str(n["scale"])) r = re.search(r"^([0-9]+)", str(n["scale"]))
if r: scales.append(int(r.group(1))) if r:
scales.append(int(r.group(1)))
if goodcorp: if goodcorp:
if "tag_kwd" not in cv: cv["tag_kwd"] = [] if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好公司") cv["tag_kwd"].append("好公司")
if goodcorp_: if goodcorp_:
if "tag_kwd" not in cv: cv["tag_kwd"] = [] if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好公司(曾)") cv["tag_kwd"].append("好公司(曾)")
if corp_tags: if corp_tags:
if "tag_kwd" not in cv: cv["tag_kwd"] = [] if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].extend(corp_tags) cv["tag_kwd"].extend(corp_tags)
cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)] cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)]
if latest_job_tm: cv["latest_job_dt"] = latest_job_tm if latest_job_tm:
if fea["corporation_id"]: cv["corporation_id"] = fea["corporation_id"] cv["latest_job_dt"] = latest_job_tm
if fea["corporation_id"]:
cv["corporation_id"] = fea["corporation_id"]
if fea["position_name"]: if fea["position_name"]:
cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0]) cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0])
@ -317,18 +368,23 @@ def forWork(cv):
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0]) cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0])
cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:])) cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:]))
if fea["subordinates_count"]: fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if if fea["subordinates_count"]:
fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
re.match(r"[^0-9]+$", str(i))] re.match(r"[^0-9]+$", str(i))]
if fea["subordinates_count"]: cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"]) if fea["subordinates_count"]:
cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
if type(cv.get("corporation_id")) == type(1): cv["corporation_id"] = [str(cv["corporation_id"])] if isinstance(cv.get("corporation_id"), int):
if not cv.get("corporation_id"): cv["corporation_id"] = [] cv["corporation_id"] = [str(cv["corporation_id"])]
if not cv.get("corporation_id"):
cv["corporation_id"] = []
for i in cv.get("corporation_id", []): for i in cv.get("corporation_id", []):
cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0) cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0)
if work_st_tm: if work_st_tm:
try: try:
if re.match(r"[0-9]{9,}", work_st_tm): work_st_tm = turnTm2Dt(work_st_tm) if re.match(r"[0-9]{9,}", work_st_tm):
work_st_tm = turnTm2Dt(work_st_tm)
y, m, d = getYMD(work_st_tm) y, m, d = getYMD(work_st_tm)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000)) cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e: except Exception as e:
@ -339,28 +395,37 @@ def forWork(cv):
cv["dua_flt"] = np.mean(duas) cv["dua_flt"] = np.mean(duas)
cv["cur_dua_int"] = duas[0] cv["cur_dua_int"] = duas[0]
cv["job_num_int"] = len(duas) cv["job_num_int"] = len(duas)
if scales: cv["scale_flt"] = np.max(scales) if scales:
cv["scale_flt"] = np.max(scales)
return cv return cv
def turnTm2Dt(b): def turnTm2Dt(b):
if not b: return if not b:
return
b = str(b).strip() b = str(b).strip()
if re.match(r"[0-9]{10,}", b): b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10]))) if re.match(r"[0-9]{10,}", b):
b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
return b return b
def getYMD(b): def getYMD(b):
y, m, d = "", "", "01" y, m, d = "", "", "01"
if not b: return (y, m, d) if not b:
return (y, m, d)
b = turnTm2Dt(b) b = turnTm2Dt(b)
if re.match(r"[0-9]{4}", b): y = int(b[:4]) if re.match(r"[0-9]{4}", b):
y = int(b[:4])
r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b) r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b)
if r: m = r.group(1) if r:
m = r.group(1)
r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b) r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b)
if r: d = r.group(1) if r:
if not d or int(d) == 0 or int(d) > 31: d = "1" d = r.group(1)
if not m or int(m) > 12 or int(m) < 1: m = "1" if not d or int(d) == 0 or int(d) > 31:
d = "1"
if not m or int(m) > 12 or int(m) < 1:
m = "1"
return (y, m, d) return (y, m, d)
@ -369,7 +434,8 @@ def birth(cv):
cv["integerity_flt"] *= 0.9 cv["integerity_flt"] *= 0.9
return cv return cv
y, m, d = getYMD(cv["birth"]) y, m, d = getYMD(cv["birth"])
if not m or not y: return cv if not m or not y:
return cv
b = "%s-%02d-%02d" % (y, int(m), int(d)) b = "%s-%02d-%02d" % (y, int(m), int(d))
cv["birth_dt"] = b cv["birth_dt"] = b
cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d)) cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d))
@ -380,7 +446,8 @@ def birth(cv):
def parse(cv): def parse(cv):
for k in cv.keys(): for k in cv.keys():
if cv[k] == '\\N': cv[k] = '' if cv[k] == '\\N':
cv[k] = ''
# cv = cv.asDict() # cv = cv.asDict()
tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names", tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names",
"expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name", "expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name",
@ -402,9 +469,12 @@ def parse(cv):
rmkeys = [] rmkeys = []
for k in cv.keys(): for k in cv.keys():
if cv[k] is None: rmkeys.append(k) if cv[k] is None:
if (type(cv[k]) == type([]) or type(cv[k]) == type("")) and len(cv[k]) == 0: rmkeys.append(k) rmkeys.append(k)
for k in rmkeys: del cv[k] if (isinstance(cv[k], list) or isinstance(cv[k], str)) and len(cv[k]) == 0:
rmkeys.append(k)
for k in rmkeys:
del cv[k]
integerity = 0. integerity = 0.
flds_num = 0. flds_num = 0.
@ -414,7 +484,8 @@ def parse(cv):
flds_num += len(flds) flds_num += len(flds)
for f in flds: for f in flds:
v = str(cv.get(f, "")) v = str(cv.get(f, ""))
if len(v) > 0 and v != '0' and v != '[]': integerity += 1 if len(v) > 0 and v != '0' and v != '[]':
integerity += 1
hasValues(tks_fld) hasValues(tks_fld)
hasValues(small_tks_fld) hasValues(small_tks_fld)
@ -433,7 +504,8 @@ def parse(cv):
(r"[ \(\)人/·0-9-]+", ""), (r"[ \(\)人/·0-9-]+", ""),
(r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]: (r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]:
cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE) cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE)
if len(cv["corporation_type"]) < 2: del cv["corporation_type"] if len(cv["corporation_type"]) < 2:
del cv["corporation_type"]
if cv.get("political_status"): if cv.get("political_status"):
for p, r in [ for p, r in [
@ -441,9 +513,11 @@ def parse(cv):
(r".*(无党派|公民).*", "群众"), (r".*(无党派|公民).*", "群众"),
(r".*团员.*", "团员")]: (r".*团员.*", "团员")]:
cv["political_status"] = re.sub(p, r, cv["political_status"]) cv["political_status"] = re.sub(p, r, cv["political_status"])
if not re.search(r"[党团群]", cv["political_status"]): del cv["political_status"] if not re.search(r"[党团群]", cv["political_status"]):
del cv["political_status"]
if cv.get("phone"): cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"])) if cv.get("phone"):
cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
keys = list(cv.keys()) keys = list(cv.keys())
for k in keys: for k in keys:
@ -454,9 +528,11 @@ def parse(cv):
cv[k] = [a for _, a in cv[k].items()] cv[k] = [a for _, a in cv[k].items()]
nms = [] nms = []
for n in cv[k]: for n in cv[k]:
if type(n) != type({}) or "name" not in n or not n.get("name"): continue if not isinstance(n, dict) or "name" not in n or not n.get("name"):
continue
n["name"] = re.sub(r"(442|\t )", "", n["name"]).strip().lower() n["name"] = re.sub(r"(442|\t )", "", n["name"]).strip().lower()
if not n["name"]: continue if not n["name"]:
continue
nms.append(n["name"]) nms.append(n["name"])
if nms: if nms:
t = k[:-4] t = k[:-4]
@ -469,15 +545,18 @@ def parse(cv):
# tokenize fields # tokenize fields
if k in tks_fld: if k in tks_fld:
cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k]) cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k])
if k in small_tks_fld: cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"]) if k in small_tks_fld:
cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
# keyword fields # keyword fields
if k in kwd_fld: cv[f"{k}_kwd"] = [n.lower() if k in kwd_fld:
cv[f"{k}_kwd"] = [n.lower()
for n in re.split(r"[\t,;. ]", for n in re.split(r"[\t,;. ]",
re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1\2", cv[k]) re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1\2", cv[k])
) if n] ) if n]
if k in num_fld and cv.get(k): cv[f"{k}_int"] = cv[k] if k in num_fld and cv.get(k):
cv[f"{k}_int"] = cv[k]
cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "") cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "")
# for name field # for name field
@ -501,10 +580,12 @@ def parse(cv):
cv["name_py_pref0_tks"] = "" cv["name_py_pref0_tks"] = ""
cv["name_py_pref_tks"] = "" cv["name_py_pref_tks"] = ""
for py in PY.get_pinyins(nm[:20], ''): for py in PY.get_pinyins(nm[:20], ''):
for i in range(2, len(py) + 1): cv["name_py_pref_tks"] += " " + py[:i] for i in range(2, len(py) + 1):
cv["name_py_pref_tks"] += " " + py[:i]
for py in PY.get_pinyins(nm[:20], ' '): for py in PY.get_pinyins(nm[:20], ' '):
py = py.split() py = py.split()
for i in range(1, len(py) + 1): cv["name_py_pref0_tks"] += " " + "".join(py[:i]) for i in range(1, len(py) + 1):
cv["name_py_pref0_tks"] += " " + "".join(py[:i])
cv["name_kwd"] = name cv["name_kwd"] = name
cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3] cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3]
@ -526,22 +607,30 @@ def parse(cv):
cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S') cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S')
else: else:
y, m, d = getYMD(str(cv.get("updated_at", ""))) y, m, d = getYMD(str(cv.get("updated_at", "")))
if not y: y = "2012" if not y:
if not m: m = "01" y = "2012"
if not d: d = "01" if not m:
m = "01"
if not d:
d = "01"
cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d)) cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
# long text tokenize # long text tokenize
if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"])) if cv.get("responsibilities"):
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
# for yes or no field # for yes or no field
fea = [] fea = []
for f, y, n in is_fld: for f, y, n in is_fld:
if f not in cv: continue if f not in cv:
        if cv[f] == '是': fea.append(y)             continue
        if cv[f] == '否': fea.append(n)         if cv[f] == '是':
            fea.append(y)
        if cv[f] == '否':
            fea.append(n)
if fea: cv["tag_kwd"] = fea if fea:
cv["tag_kwd"] = fea
cv = forEdu(cv) cv = forEdu(cv)
cv = forProj(cv) cv = forProj(cv)
@ -550,9 +639,11 @@ def parse(cv):
cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])] cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])]
for i in range(len(cv["corp_proj_sch_deg_kwd"])): for i in range(len(cv["corp_proj_sch_deg_kwd"])):
for j in cv.get("sch_rank_kwd", []): cv["corp_proj_sch_deg_kwd"][i] += "+" + j for j in cv.get("sch_rank_kwd", []):
cv["corp_proj_sch_deg_kwd"][i] += "+" + j
for i in range(len(cv["corp_proj_sch_deg_kwd"])): for i in range(len(cv["corp_proj_sch_deg_kwd"])):
if cv.get("highest_degree_kwd"): cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"] if cv.get("highest_degree_kwd"):
cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
try: try:
if not cv.get("work_exp_flt") and cv.get("work_start_time"): if not cv.get("work_exp_flt") and cv.get("work_start_time"):
@ -565,17 +656,21 @@ def parse(cv):
cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y) cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y)
except Exception as e: except Exception as e:
logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time"))) logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time")))
if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12. if "work_exp_flt" not in cv and cv.get("work_experience", 0):
cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
keys = list(cv.keys()) keys = list(cv.keys())
for k in keys: for k in keys:
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k): del cv[k] if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k):
del cv[k]
for k in cv.keys(): for k in cv.keys():
if not re.search("_(kwd|id)$", k) or type(cv[k]) != type([]): continue if not re.search("_(kwd|id)$", k) or not isinstance(cv[k], list):
continue
cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']])) cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']]))
keys = [k for k in cv.keys() if re.search(r"_feas*$", k)] keys = [k for k in cv.keys() if re.search(r"_feas*$", k)]
for k in keys: for k in keys:
if cv[k] <= 0: del cv[k] if cv[k] <= 0:
del cv[k]
cv["tob_resume_id"] = str(cv["tob_resume_id"]) cv["tob_resume_id"] = str(cv["tob_resume_id"])
cv["id"] = cv["tob_resume_id"] cv["id"] = cv["tob_resume_id"]
@ -592,5 +687,6 @@ def dealWithInt64(d):
if isinstance(d, list): if isinstance(d, list):
d = [dealWithInt64(t) for t in d] d = [dealWithInt64(t) for t in d]
if isinstance(d, np.integer): d = int(d) if isinstance(d, np.integer):
d = int(d)
return d return d

View File

@ -51,6 +51,7 @@ class RAGFlowTxtParser:
dels = [d for d in dels if d] dels = [d for d in dels if d]
dels = "|".join(dels) dels = "|".join(dels)
secs = re.split(r"(%s)" % dels, txt) secs = re.split(r"(%s)" % dels, txt)
for sec in secs: add_chunk(sec) for sec in secs:
add_chunk(sec)
return [[c, ""] for c in cks] return [[c, ""] for c in cks]

View File

@ -18,7 +18,6 @@ from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer from .table_structure_recognizer import TableStructureRecognizer
def init_in_out(args): def init_in_out(args):
from PIL import Image from PIL import Image
import os import os
@ -47,7 +46,7 @@ def init_in_out(args):
try: try:
images.append(Image.open(fnm)) images.append(Image.open(fnm))
outputs.append(os.path.split(fnm)[-1]) outputs.append(os.path.split(fnm)[-1])
except Exception as e: except Exception:
traceback.print_exc() traceback.print_exc()
if os.path.isdir(args.inputs): if os.path.isdir(args.inputs):
@ -56,6 +55,16 @@ def init_in_out(args):
else: else:
images_and_outputs(args.inputs) images_and_outputs(args.inputs)
for i in range(len(outputs)): outputs[i] = os.path.join(args.output_dir, outputs[i]) for i in range(len(outputs)):
outputs[i] = os.path.join(args.output_dir, outputs[i])
return images, outputs return images, outputs
__all__ = [
"OCR",
"Recognizer",
"LayoutRecognizer",
"TableStructureRecognizer",
"init_in_out",
]

View File

@ -42,7 +42,7 @@ class LayoutRecognizer(Recognizer):
get_project_base_directory(), get_project_base_directory(),
"rag/res/deepdoc") "rag/res/deepdoc")
super().__init__(self.labels, domain, model_dir) super().__init__(self.labels, domain, model_dir)
except Exception as e: except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False) local_dir_use_symlinks=False)
@ -77,7 +77,7 @@ class LayoutRecognizer(Recognizer):
"page_number": pn, "page_number": pn,
} for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts] } for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
lts = self.sort_Y_firstly(lts, np.mean( lts = self.sort_Y_firstly(lts, np.mean(
[l["bottom"] - l["top"] for l in lts]) / 2) [lt["bottom"] - lt["top"] for lt in lts]) / 2)
lts = self.layouts_cleanup(bxs, lts) lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts) page_layout.append(lts)

View File

@ -19,7 +19,9 @@ from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from .operators import * from .operators import *
import math
import numpy as np import numpy as np
import cv2
import onnxruntime as ort import onnxruntime as ort
from .postprocess import build_post_process from .postprocess import build_post_process
@ -484,7 +486,7 @@ class OCR(object):
"rag/res/deepdoc") "rag/res/deepdoc")
self.text_detector = TextDetector(model_dir) self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir) self.text_recognizer = TextRecognizer(model_dir)
except Exception as e: except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False) local_dir_use_symlinks=False)

View File

@ -232,7 +232,7 @@ class LinearResize(object):
""" """
assert len(self.target_size) == 2 assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0 assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2] _im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im) im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize( im = cv2.resize(
im, im,
@ -255,7 +255,7 @@ class LinearResize(object):
im_scale_y: the resize ratio of Y im_scale_y: the resize ratio of Y
""" """
origin_shape = im.shape[:2] origin_shape = im.shape[:2]
im_c = im.shape[2] _im_c = im.shape[2]
if self.keep_ratio: if self.keep_ratio:
im_size_min = np.min(origin_shape) im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape) im_size_max = np.max(origin_shape)
@ -581,7 +581,7 @@ class SRResize(object):
return data return data
images_HR = data["image_hr"] images_HR = data["image_hr"]
label_strs = data["label"] _label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH)) transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR) images_HR = transform(images_HR)
data["img_hr"] = images_HR data["img_hr"] = images_HR

View File

@ -121,7 +121,7 @@ class DBPostProcess(object):
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE) cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3: if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2] _img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2: elif len(outs) == 2:
contours, _ = outs[0], outs[1] contours, _ = outs[0], outs[1]

View File

@ -13,15 +13,18 @@
import logging import logging
import os import os
import math
import numpy as np
import cv2
from copy import deepcopy from copy import deepcopy
import onnxruntime as ort import onnxruntime as ort
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from .operators import * from .operators import *
class Recognizer(object): class Recognizer(object):
def __init__(self, label_list, task_name, model_dir=None): def __init__(self, label_list, task_name, model_dir=None):
""" """
@ -277,7 +280,8 @@ class Recognizer(object):
return return
min_dis, min_i = 1000000, None min_dis, min_i = 1000000, None
for i,b in enumerate(boxes): for i,b in enumerate(boxes):
if box.get("layoutno", "0") != b.get("layoutno", "0"): continue if box.get("layoutno", "0") != b.get("layoutno", "0"):
continue
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2) dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
if dis < min_dis: if dis < min_dis:
min_i = i min_i = i
@ -402,7 +406,8 @@ class Recognizer(object):
scores = np.max(boxes[:, 4:], axis=1) scores = np.max(boxes[:, 4:], axis=1)
boxes = boxes[scores > thr, :] boxes = boxes[scores > thr, :]
scores = scores[scores > thr] scores = scores[scores > thr]
if len(boxes) == 0: return [] if len(boxes) == 0:
return []
# Get the class with the highest confidence # Get the class with the highest confidence
class_ids = np.argmax(boxes[:, 4:], axis=1) class_ids = np.argmax(boxes[:, 4:], axis=1)
@ -432,7 +437,8 @@ class Recognizer(object):
for i in range(len(image_list)): for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray): if not isinstance(image_list[i], np.ndarray):
imgs.append(np.array(image_list[i])) imgs.append(np.array(image_list[i]))
else: imgs.append(image_list[i]) else:
imgs.append(image_list[i])
batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size) batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
for i in range(batch_loop_cnt): for i in range(batch_loop_cnt):

View File

@ -88,7 +88,8 @@ class CommunityReportsExtractor:
("findings", list), ("findings", list),
("rating", float), ("rating", float),
("rating_explanation", str), ("rating_explanation", str),
]): continue ]):
continue
response["weight"] = weight response["weight"] = weight
response["entities"] = ents response["entities"] = ents
except Exception as e: except Exception as e:
@ -100,7 +101,8 @@ class CommunityReportsExtractor:
res_str.append(self._get_text_output(response)) res_str.append(self._get_text_output(response))
res_dict.append(response) res_dict.append(response)
over += 1 over += 1
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}") if callback:
callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
return CommunityReportsResult( return CommunityReportsResult(
structured_output=res_dict, structured_output=res_dict,

View File

@ -8,6 +8,7 @@ Reference:
from typing import Any from typing import Any
import numpy as np import numpy as np
import networkx as nx import networkx as nx
from dataclasses import dataclass
from graphrag.leiden import stable_largest_connected_component from graphrag.leiden import stable_largest_connected_component

View File

@ -129,9 +129,11 @@ class GraphExtractor:
source_doc_map[doc_index] = text source_doc_map[doc_index] = text
all_records[doc_index] = result all_records[doc_index] = result
total_token_count += token_count total_token_count += token_count
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}") if callback:
callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e: except Exception as e:
if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e))) if callback:
callback(msg="Knowledge graph extraction error:{}".format(str(e)))
logging.exception("error extracting graph") logging.exception("error extracting graph")
self._on_error( self._on_error(
e, e,
@ -164,7 +166,8 @@ class GraphExtractor:
text = perform_variable_replacements(self._extraction_prompt, variables=variables) text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.3} gen_conf = {"temperature": 0.3}
response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf) response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
if response.find("**ERROR**") >= 0: raise Exception(response) if response.find("**ERROR**") >= 0:
raise Exception(response)
token_count = num_tokens_from_string(text + response) token_count = num_tokens_from_string(text + response)
results = response or "" results = response or ""
@ -175,7 +178,8 @@ class GraphExtractor:
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables) text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text}) history.append({"role": "user", "content": text})
response = self._llm.chat("", history, gen_conf) response = self._llm.chat("", history, gen_conf)
if response.find("**ERROR**") >=0: raise Exception(response) if response.find("**ERROR**") >=0:
raise Exception(response)
results += response or "" results += response or ""
# if this is the final glean, don't bother updating the continuation flag # if this is the final glean, don't bother updating the continuation flag

View File

@ -134,7 +134,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
callback(0.75, "Extracting mind graph.") callback(0.75, "Extracting mind graph.")
mindmap = MindMapExtractor(llm_bdl) mindmap = MindMapExtractor(llm_bdl)
mg = mindmap(_chunks).output mg = mindmap(_chunks).output
if not len(mg.keys()): return chunks if not len(mg.keys()):
return chunks
logging.debug(json.dumps(mg, ensure_ascii=False, indent=2)) logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
chunks.append( chunks.append(

View File

@ -78,7 +78,8 @@ def _compute_leiden_communities(
) -> dict[int, dict[str, int]]: ) -> dict[int, dict[str, int]]:
"""Return Leiden root communities.""" """Return Leiden root communities."""
results: dict[int, dict[str, int]] = {} results: dict[int, dict[str, int]] = {}
if is_empty(graph): return results if is_empty(graph):
return results
if use_lcc: if use_lcc:
graph = stable_largest_connected_component(graph) graph = stable_largest_connected_component(graph)
@ -100,7 +101,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
logging.debug( logging.debug(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
) )
if not graph.nodes(): return {} if not graph.nodes():
return {}
node_id_to_community_map = _compute_leiden_communities( node_id_to_community_map = _compute_leiden_communities(
graph=graph, graph=graph,
@ -125,9 +127,11 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
result[community_id]["nodes"].append(node_id) result[community_id]["nodes"].append(node_id)
result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1) result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
weights = [comm["weight"] for _, comm in result.items()] weights = [comm["weight"] for _, comm in result.items()]
if not weights:continue if not weights:
continue
max_weight = max(weights) max_weight = max(weights)
for _, comm in result.items(): comm["weight"] /= max_weight for _, comm in result.items():
comm["weight"] /= max_weight
return results_by_level return results_by_level

View File

@ -1 +1,5 @@
from .ragflow_chat import * from .ragflow_chat import RAGFlowChat
__all__ = [
"RAGFlowChat"
]

View File

@ -2,7 +2,6 @@ import logging
import requests import requests
from bridge.context import ContextType # Import Context, ContextType from bridge.context import ContextType # Import Context, ContextType
from bridge.reply import Reply, ReplyType # Import Reply, ReplyType from bridge.reply import Reply, ReplyType # Import Reply, ReplyType
from bridge import *
from plugins import Plugin, register # Import Plugin and register from plugins import Plugin, register # Import Plugin and register
from plugins.event import Event, EventContext, EventAction # Import event-related classes from plugins.event import Event, EventContext, EventAction # Import event-related classes

View File

@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = get_text(filename, binary) txt = get_text(filename, binary)
sections = txt.split("\n") sections = txt.split("\n")
sections = [(l, "") for l in sections if l] sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english( remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200))) random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
@ -102,7 +102,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary) sections = HtmlParser()(filename, binary)
sections = [(l, "") for l in sections if l] sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english( remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200))) random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
@ -112,7 +112,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary) binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary) doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n') sections = doc_parsed['content'].split('\n')
sections = [(l, "") for l in sections if l] sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english( remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200))) random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")

View File

@ -75,7 +75,7 @@ def chunk(
_add_content(msg, msg.get_content_type()) _add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [ sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(l, "") for l in HtmlParser.parser_txt("\n".join(html_txt)) if l (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line
] ]
st = timer() st = timer()

View File

@ -18,7 +18,8 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
chunks = build_knowledge_graph_chunks(tenant_id, sections, callback, chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
parser_config.get("entity_types", ["organization", "person", "location", "event", "time"]) parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
) )
for c in chunks: c["docnm_kwd"] = filename for c in chunks:
c["docnm_kwd"] = filename
doc = { doc = {
"docnm_kwd": filename, "docnm_kwd": filename,

View File

@ -48,7 +48,7 @@ class Docx(DocxParser):
continue continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1 pn += 1
return [l for l in lines if l] return [line for line in lines if line]
def __call__(self, filename, binary=None, from_page=0, to_page=100000): def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document( self.doc = Document(
@ -60,7 +60,8 @@ class Docx(DocxParser):
if pn > to_page: if pn > to_page:
break break
question_level, p_text = docx_question_level(p, bull) question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"):continue if not p_text.strip("\n"):
continue
lines.append((question_level, p_text)) lines.append((question_level, p_text))
for run in p.runs: for run in p.runs:
@ -78,19 +79,21 @@ class Docx(DocxParser):
if lines[e][0] <= lines[s][0]: if lines[e][0] <= lines[s][0]:
break break
e += 1 e += 1
if e - s == 1 and visit[s]: continue if e - s == 1 and visit[s]:
continue
sec = [] sec = []
next_level = lines[s][0] + 1 next_level = lines[s][0] + 1
while not sec and next_level < 22: while not sec and next_level < 22:
for i in range(s+1, e): for i in range(s+1, e):
if lines[i][0] != next_level: continue if lines[i][0] != next_level:
continue
sec.append(lines[i][1]) sec.append(lines[i][1])
visit[i] = True visit[i] = True
next_level += 1 next_level += 1
sec.insert(0, lines[s][1]) sec.insert(0, lines[s][1])
sections.append("\n".join(sec)) sections.append("\n".join(sec))
return [l for l in sections if l] return [s for s in sections if s]
def __str__(self) -> str: def __str__(self) -> str:
return f''' return f'''
@ -168,13 +171,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = get_text(filename, binary) txt = get_text(filename, binary)
sections = txt.split("\n") sections = txt.split("\n")
sections = [l for l in sections if l] sections = [s for s in sections if s]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary) sections = HtmlParser()(filename, binary)
sections = [l for l in sections if l] sections = [s for s in sections if s]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.doc$", filename, re.IGNORECASE): elif re.search(r"\.doc$", filename, re.IGNORECASE):
@ -182,7 +185,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary) binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary) doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n') sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l] sections = [s for s in sections if s]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else: else:


@ -190,7 +190,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections, tbls = pdf_parser(filename if not binary else binary, sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback) from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0]) < 3: if sections and len(sections[0]) < 3:
sections = [(t, l, [[0] * 5]) for t, l in sections] sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
# set pivot using the most frequent type of title, # set pivot using the most frequent type of title,
# then merge between 2 pivot # then merge between 2 pivot
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1: if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1:
@ -211,7 +211,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
else: else:
bull = bullets_category([txt for txt, _, _ in sections]) bull = bullets_category([txt for txt, _, _ in sections])
most_level, levels = title_frequency( most_level, levels = title_frequency(
bull, [(txt, l) for txt, l, poss in sections]) bull, [(txt, lvl) for txt, lvl, _ in sections])
assert len(sections) == len(levels) assert len(sections) == len(levels)
sec_ids = [] sec_ids = []
@ -225,7 +225,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections = [(txt, sec_ids[i], poss) sections = [(txt, sec_ids[i], poss)
for i, (txt, _, poss) in enumerate(sections)] for i, (txt, _, poss) in enumerate(sections)]
for (img, rows), poss in tbls: for (img, rows), poss in tbls:
if not rows: continue if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], -1, sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))


@ -54,7 +54,8 @@ class Pdf(PdfParser):
sections = [(b["text"], self.get_position(b, zoomin)) sections = [(b["text"], self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)] for i, b in enumerate(self.boxes)]
for (img, rows), poss in tbls: for (img, rows), poss in tbls:
if not rows:continue if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: ( return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
@ -109,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary) binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary) doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n') sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l] sections = [s for s in sections if s]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else: else:


@ -171,7 +171,7 @@ class Pdf(PdfParser):
tbl_bottom = tbls[tbl_index][1][0][4] tbl_bottom = tbls[tbl_index][1][0][4]
tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
.format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom) .format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
tbl_text = ''.join(tbls[tbl_index][0][1]) _tbl_text = ''.join(tbls[tbl_index][0][1])
return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag,
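
Here the unused local `tbl_text` is renamed to `_tbl_text` rather than deleted: a leading underscore is the usual way to mark a value as intentionally unused and quiet "assigned but never used" findings (F841-style) while leaving the computation visible. A hedged sketch of the convention with hypothetical arguments:

```python
def table_tag(rows, pn, left, right, top, bottom):
    # Joined text is computed but not returned; the underscore prefix marks
    # it as deliberately unused so linters do not report it.
    _tbl_text = "".join(rows)
    return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format(pn, left, right, top, bottom)

print(table_tag(["a", "b"], 1, 0.0, 10.0, 0.0, 5.0))
```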
@ -325,9 +325,11 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
txt = get_text(filename, binary) txt = get_text(filename, binary)
lines = txt.split("\n") lines = txt.split("\n")
comma, tab = 0, 0 comma, tab = 0, 0
for l in lines: for line in lines:
if len(l.split(",")) == 2: comma += 1 if len(line.split(",")) == 2:
if len(l.split("\t")) == 2: tab += 1 comma += 1
if len(line.split("\t")) == 2:
tab += 1
delimiter = "\t" if tab >= comma else "," delimiter = "\t" if tab >= comma else ","
fails = [] fails = []
@ -336,18 +338,21 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
while i < len(lines): while i < len(lines):
arr = lines[i].split(delimiter) arr = lines[i].split(delimiter)
if len(arr) != 2: if len(arr) != 2:
if question: answer += "\n" + lines[i] if question:
answer += "\n" + lines[i]
else: else:
fails.append(str(i+1)) fails.append(str(i+1))
elif len(arr) == 2: elif len(arr) == 2:
if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng)) if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
question, answer = arr question, answer = arr
i += 1 i += 1
if len(res) % 999 == 0: if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
if question: res.append(beAdoc(deepcopy(doc), question, answer, eng)) if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@ -367,19 +372,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = get_text(filename, binary) txt = get_text(filename, binary)
lines = txt.split("\n") lines = txt.split("\n")
last_question, last_answer = "", "" _last_question, last_answer = "", ""
question_stack, level_stack = [], [] question_stack, level_stack = [], []
code_block = False code_block = False
level_index = [-1] * 7 for index, line in enumerate(lines):
for index, l in enumerate(lines): if line.strip().startswith('```'):
if l.strip().startswith('```'):
code_block = not code_block code_block = not code_block
question_level, question = 0, '' question_level, question = 0, ''
if not code_block: if not code_block:
question_level, question = mdQuestionLevel(l) question_level, question = mdQuestionLevel(line)
if not question_level or question_level > 6: # not a question if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{l}' last_answer = f'{last_answer}\n{line}'
else: # is a question else: # is a question
if last_answer.strip(): if last_answer.strip():
sum_question = '\n'.join(question_stack) sum_question = '\n'.join(question_stack)


@ -41,14 +41,16 @@ class Excel(ExcelParser):
for sheetname in wb.sheetnames: for sheetname in wb.sheetnames:
ws = wb[sheetname] ws = wb[sheetname]
rows = list(ws.rows) rows = list(ws.rows)
if not rows:continue if not rows:
continue
headers = [cell.value for cell in rows[0]] headers = [cell.value for cell in rows[0]]
missed = set([i for i, h in enumerate(headers) if h is None]) missed = set([i for i, h in enumerate(headers) if h is None])
headers = [ headers = [
cell.value for i, cell.value for i,
cell in enumerate( cell in enumerate(
rows[0]) if i not in missed] rows[0]) if i not in missed]
if not headers:continue if not headers:
continue
data = [] data = []
for i, r in enumerate(rows[1:]): for i, r in enumerate(rows[1:]):
rn += 1 rn += 1
@ -88,7 +90,6 @@ def trans_bool(s):
def column_data_type(arr): def column_data_type(arr):
arr = list(arr) arr = list(arr)
uni = len(set([a for a in arr if a is not None]))
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0} counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t: f for f, t in trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
@ -157,7 +158,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
continue continue
if i >= to_page: if i >= to_page:
break break
row = [l for l in line.split(kwargs.get("delimiter", "\t"))] row = [field for field in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers): if len(row) != len(headers):
fails.append(str(i)) fails.append(str(i))
continue continue


@ -13,12 +13,124 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from .embedding_model import * from .embedding_model import (
from .chat_model import * OllamaEmbed,
from .cv_model import * LocalAIEmbed,
from .rerank_model import * OpenAIEmbed,
from .sequence2txt_model import * AzureEmbed,
from .tts_model import * XinferenceEmbed,
QWenEmbed,
ZhipuEmbed,
FastEmbed,
YoudaoEmbed,
BaiChuanEmbed,
JinaEmbed,
DefaultEmbedding,
MistralEmbed,
BedrockEmbed,
GeminiEmbed,
NvidiaEmbed,
LmStudioEmbed,
OpenAI_APIEmbed,
CoHereEmbed,
TogetherAIEmbed,
PerfXCloudEmbed,
UpstageEmbed,
SILICONFLOWEmbed,
ReplicateEmbed,
BaiduYiyanEmbed,
VoyageEmbed,
HuggingFaceEmbed,
VolcEngineEmbed,
)
from .chat_model import (
GptTurbo,
AzureChat,
ZhipuChat,
QWenChat,
OllamaChat,
LocalAIChat,
XinferenceChat,
MoonshotChat,
DeepSeekChat,
VolcEngineChat,
BaiChuanChat,
MiniMaxChat,
MistralChat,
GeminiChat,
BedrockChat,
GroqChat,
OpenRouterChat,
StepFunChat,
NvidiaChat,
LmStudioChat,
OpenAI_APIChat,
CoHereChat,
LeptonAIChat,
TogetherAIChat,
PerfXCloudChat,
UpstageChat,
NovitaAIChat,
SILICONFLOWChat,
YiChat,
ReplicateChat,
HunyuanChat,
SparkChat,
BaiduYiyanChat,
AnthropicChat,
GoogleChat,
HuggingFaceChat,
)
from .cv_model import (
GptV4,
AzureGptV4,
OllamaCV,
XinferenceCV,
QWenCV,
Zhipu4V,
LocalCV,
GeminiCV,
OpenRouterCV,
LocalAICV,
NvidiaCV,
LmStudioCV,
StepFunCV,
OpenAI_APICV,
TogetherAICV,
YiCV,
HunyuanCV,
)
from .rerank_model import (
LocalAIRerank,
DefaultRerank,
JinaRerank,
YoudaoRerank,
XInferenceRerank,
NvidiaRerank,
LmStudioRerank,
OpenAI_APIRerank,
CoHereRerank,
TogetherAIRerank,
SILICONFLOWRerank,
BaiduYiyanRerank,
VoyageRerank,
QWenRerank,
)
from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
AzureSeq2txt,
XinferenceSeq2txt,
TencentCloudSeq2txt,
)
from .tts_model import (
FishAudioTTS,
QwenTTS,
OpenAITTS,
SparkTTS,
XinferenceTTS,
)
EmbeddingModel = { EmbeddingModel = {
"Ollama": OllamaEmbed, "Ollama": OllamaEmbed,
@ -48,7 +160,7 @@ EmbeddingModel = {
"BaiduYiyan": BaiduYiyanEmbed, "BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed, "Voyage AI": VoyageEmbed,
"HuggingFace": HuggingFaceEmbed, "HuggingFace": HuggingFaceEmbed,
"VolcEngine":VolcEngineEmbed, "VolcEngine": VolcEngineEmbed,
} }
CvModel = { CvModel = {
@ -68,7 +180,7 @@ CvModel = {
"OpenAI-API-Compatible": OpenAI_APICV, "OpenAI-API-Compatible": OpenAI_APICV,
"TogetherAI": TogetherAICV, "TogetherAI": TogetherAICV,
"01.AI": YiCV, "01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV "Tencent Hunyuan": HunyuanCV,
} }
ChatModel = { ChatModel = {
@ -111,7 +223,7 @@ ChatModel = {
} }
RerankModel = { RerankModel = {
"LocalAI":LocalAIRerank, "LocalAI": LocalAIRerank,
"BAAI": DefaultRerank, "BAAI": DefaultRerank,
"Jina": JinaRerank, "Jina": JinaRerank,
"Youdao": YoudaoRerank, "Youdao": YoudaoRerank,
@ -132,7 +244,7 @@ Seq2txtModel = {
"Tongyi-Qianwen": QWenSeq2txt, "Tongyi-Qianwen": QWenSeq2txt,
"Azure-OpenAI": AzureSeq2txt, "Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt, "Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt "Tencent Cloud": TencentCloudSeq2txt,
} }
TTSModel = { TTSModel = {
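
This `__init__.py` replaces wildcard imports (`from .embedding_model import *`) with explicit name lists, which is what Ruff's F403 ("star import used") and F405 ("may be undefined from star imports") point at: with a wildcard, static analysis cannot tell where a name comes from or whether it exists. A tiny runnable sketch of the same refactor using a stdlib module as a stand-in for the repository's model modules:

```python
# Before: a wildcard import pulls in an unknown set of names (F403),
# and every later use of those names is only "maybe defined" (F405).
# from os.path import *

# After: the names the module relies on are spelled out, so typos and
# unused imports (F401) are caught immediately.
from os.path import basename, splitext

print(basename("/tmp/report.pdf"), splitext("report.pdf")[1])
```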


@ -69,7 +69,8 @@ class Base(ABC):
stream=True, stream=True,
**gen_conf) **gen_conf)
for resp in response: for resp in response:
if not resp.choices: continue if not resp.choices:
continue
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content ans += resp.choices[0].delta.content
@ -81,7 +82,8 @@ class Base(ABC):
) )
elif isinstance(resp.usage, dict): elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens) total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens else:
total_tokens = resp.usage.total_tokens
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
if is_chinese(ans): if is_chinese(ans):
@ -98,13 +100,15 @@ class Base(ABC):
class GptTurbo(Base): class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1" if not base_url:
base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
class MoonshotChat(Base): class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"): def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url = "https://api.moonshot.cn/v1" if not base_url:
base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
@ -128,7 +132,8 @@ class HuggingFaceChat(Base):
class DeepSeekChat(Base): class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"): def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url = "https://api.deepseek.com/v1" if not base_url:
base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
@ -202,7 +207,8 @@ class BaiChuanChat(Base):
stream=True, stream=True,
**self._format_params(gen_conf)) **self._format_params(gen_conf))
for resp in response: for resp in response:
if not resp.choices: continue if not resp.choices:
continue
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content ans += resp.choices[0].delta.content
@ -313,8 +319,10 @@ class ZhipuChat(Base):
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
try: try:
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "presence_penalty" in gen_conf:
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=history, messages=history,
@ -333,8 +341,10 @@ class ZhipuChat(Base):
def chat_streamly(self, system, history, gen_conf): def chat_streamly(self, system, history, gen_conf):
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "presence_penalty" in gen_conf:
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = "" ans = ""
tk_count = 0 tk_count = 0
try: try:
@ -345,7 +355,8 @@ class ZhipuChat(Base):
**gen_conf **gen_conf
) )
for resp in response: for resp in response:
if not resp.choices[0].delta.content: continue if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content delta = resp.choices[0].delta.content
ans += delta ans += delta
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
@ -354,7 +365,8 @@ class ZhipuChat(Base):
else: else:
ans += LENGTH_NOTIFICATION_EN ans += LENGTH_NOTIFICATION_EN
tk_count = resp.usage.total_tokens tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans yield ans
except Exception as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
@ -372,11 +384,16 @@ class OllamaChat(Base):
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
try: try:
options = {} options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] if "temperature" in gen_conf:
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"] if "max_tokens" in gen_conf:
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] options["num_predict"] = gen_conf["max_tokens"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat( response = self.client.chat(
model=self.model_name, model=self.model_name,
messages=history, messages=history,
@ -392,11 +409,16 @@ class OllamaChat(Base):
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
options = {} options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] if "temperature" in gen_conf:
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"] if "max_tokens" in gen_conf:
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] options["num_predict"] = gen_conf["max_tokens"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = "" ans = ""
try: try:
response = self.client.chat( response = self.client.chat(
@ -636,7 +658,8 @@ class MistralChat(Base):
messages=history, messages=history,
**gen_conf) **gen_conf)
for resp in response: for resp in response:
if not resp.choices or not resp.choices[0].delta.content: continue if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content ans += resp.choices[0].delta.content
total_tokens += 1 total_tokens += 1
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
@ -1196,7 +1219,8 @@ class SparkChat(Base):
assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}" assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
if model_name in model2version: if model_name in model2version:
model_version = model2version[model_name] model_version = model2version[model_name]
else: model_version = model_name else:
model_version = model_name
super().__init__(key, model_version, base_url) super().__init__(key, model_version, base_url)
@ -1281,8 +1305,10 @@ class AnthropicChat(Base):
self.system = system self.system = system
if "max_tokens" not in gen_conf: if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096 gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "presence_penalty" in gen_conf:
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = "" ans = ""
try: try:
@ -1312,8 +1338,10 @@ class AnthropicChat(Base):
self.system = system self.system = system
if "max_tokens" not in gen_conf: if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096 gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "presence_penalty" in gen_conf:
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = "" ans = ""
total_tokens = 0 total_tokens = 0
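
The Ollama and Zhipu wrappers above repeat the same "copy a key if present, drop it if unsupported" handling of `gen_conf`; the diff only unfolds the one-liners to satisfy E701, but the pattern itself can be written once with a mapping table. A hedged sketch of that idea — not the repository's code, just an illustration with hypothetical helper names; the parameter keys mirror the ones handled in the hunks:

```python
# OpenAI-style generation parameters -> Ollama option names.
_OPTION_MAP = {
    "temperature": "temperature",
    "max_tokens": "num_predict",
    "top_p": "top_p",
    "presence_penalty": "presence_penalty",
    "frequency_penalty": "frequency_penalty",
}

def to_ollama_options(gen_conf: dict) -> dict:
    """Copy only the supported keys, renaming where the backends differ."""
    return {dst: gen_conf[src] for src, dst in _OPTION_MAP.items() if src in gen_conf}

print(to_ollama_options({"temperature": 0.3, "max_tokens": 512, "stop": ["\n"]}))
# -> {'temperature': 0.3, 'num_predict': 512}
```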


@ -25,6 +25,7 @@ import base64
from io import BytesIO from io import BytesIO
import json import json
import requests import requests
from transformers import GenerationConfig
from rag.nlp import is_english from rag.nlp import is_english
from api.utils import get_uuid from api.utils import get_uuid
@ -77,14 +78,16 @@ class Base(ABC):
stream=True stream=True
) )
for resp in response: for resp in response:
if not resp.choices[0].delta.content: continue if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content delta = resp.choices[0].delta.content
ans += delta ans += delta
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english( ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans yield ans
except Exception as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
@ -99,7 +102,7 @@ class Base(ABC):
buffered = BytesIO() buffered = BytesIO()
try: try:
image.save(buffered, format="JPEG") image.save(buffered, format="JPEG")
except Exception as e: except Exception:
image.save(buffered, format="PNG") image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8") return base64.b64encode(buffered.getvalue()).decode("utf-8")
@ -139,7 +142,8 @@ class Base(ABC):
class GptV4(Base): class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"): def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1" if not base_url:
base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url) self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name self.model_name = model_name
self.lang = lang self.lang = lang
@ -149,7 +153,8 @@ class GptV4(Base):
prompt = self.prompt(b64) prompt = self.prompt(b64)
for i in range(len(prompt)): for i in range(len(prompt)):
for c in prompt[i]["content"]: for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text" if "text" in c:
c["type"] = "text"
res = self.client.chat.completions.create( res = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
@ -171,7 +176,8 @@ class AzureGptV4(Base):
prompt = self.prompt(b64) prompt = self.prompt(b64)
for i in range(len(prompt)): for i in range(len(prompt)):
for c in prompt[i]["content"]: for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text" if "text" in c:
c["type"] = "text"
res = self.client.chat.completions.create( res = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
@ -344,14 +350,16 @@ class Zhipu4V(Base):
stream=True stream=True
) )
for resp in response: for resp in response:
if not resp.choices[0].delta.content: continue if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content delta = resp.choices[0].delta.content
ans += delta ans += delta
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english( ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans yield ans
except Exception as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
@ -389,11 +397,16 @@ class OllamaCV(Base):
if his["role"] == "user": if his["role"] == "user":
his["images"] = [image] his["images"] = [image]
options = {} options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] if "temperature" in gen_conf:
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"] if "max_tokens" in gen_conf:
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] options["num_predict"] = gen_conf["max_tokens"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat( response = self.client.chat(
model=self.model_name, model=self.model_name,
messages=history, messages=history,
@ -414,11 +427,16 @@ class OllamaCV(Base):
if his["role"] == "user": if his["role"] == "user":
his["images"] = [image] his["images"] = [image]
options = {} options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] if "temperature" in gen_conf:
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"] if "max_tokens" in gen_conf:
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] options["num_predict"] = gen_conf["max_tokens"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = "" ans = ""
try: try:
response = self.client.chat( response = self.client.chat(
@ -469,7 +487,7 @@ class XinferenceCV(Base):
class GeminiCV(Base): class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs): def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client, GenerativeModel, GenerationConfig from google.generativeai import client, GenerativeModel
client.configure(api_key=key) client.configure(api_key=key)
_client = client.get_default_generative_client() _client = client.get_default_generative_client()
self.model_name = model_name self.model_name = model_name
@ -503,7 +521,7 @@ class GeminiCV(Base):
if his["role"] == "user": if his["role"] == "user":
his["parts"] = [his["content"]] his["parts"] = [his["content"]]
his.pop("content") his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image) history[-1]["parts"].append("data:image/jpeg;base64," + image)
response = self.model.generate_content(history, generation_config=GenerationConfig( response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3), max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
@ -519,7 +537,6 @@ class GeminiCV(Base):
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
ans = "" ans = ""
tk_count = 0
try: try:
for his in history: for his in history:
if his["role"] == "assistant": if his["role"] == "assistant":
@ -529,14 +546,15 @@ class GeminiCV(Base):
if his["role"] == "user": if his["role"] == "user":
his["parts"] = [his["content"]] his["parts"] = [his["content"]]
his.pop("content") his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image) history[-1]["parts"].append("data:image/jpeg;base64," + image)
response = self.model.generate_content(history, generation_config=GenerationConfig( response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3), max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)), stream=True) top_p=gen_conf.get("top_p", 0.7)), stream=True)
for resp in response: for resp in response:
if not resp.text: continue if not resp.text:
continue
ans += resp.text ans += resp.text
yield ans yield ans
except Exception as e: except Exception as e:
@ -632,7 +650,8 @@ class NvidiaCV(Base):
class StepFunCV(GptV4): class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"): def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url: base_url="https://api.stepfun.com/v1" if not base_url:
base_url="https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url) self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name self.model_name = model_name
self.lang = lang self.lang = lang


@ -15,12 +15,9 @@
# #
import requests import requests
from openai.lib.azure import AzureOpenAI from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io import io
from abc import ABC from abc import ABC
from ollama import Client
from openai import OpenAI from openai import OpenAI
import os
import json import json
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
import base64 import base64
@ -49,7 +46,8 @@ class Base(ABC):
class GPTSeq2txt(Base): class GPTSeq2txt(Base):
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"): def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1" if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url) self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name self.model_name = model_name


@ -16,7 +16,6 @@
import _thread as thread import _thread as thread
import base64 import base64
import datetime
import hashlib import hashlib
import hmac import hmac
import json import json
@ -175,7 +174,8 @@ class QwenTTS(Base):
class OpenAITTS(Base): class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1" if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key self.api_key = key
self.model_name = model_name self.model_name = model_name
self.base_url = base_url self.base_url = base_url


@ -222,7 +222,8 @@ def bullets_category(sections):
def is_english(texts): def is_english(texts):
eng = 0 eng = 0
if not texts: return False if not texts:
return False
for t in texts: for t in texts:
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()): if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
eng += 1 eng += 1
@ -250,7 +251,8 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = [] res = []
# wrap up as es documents # wrap up as es documents
for ck in chunks: for ck in chunks:
if len(ck.strip()) == 0:continue if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck)) logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
if pdf_parser: if pdf_parser:
@ -269,7 +271,8 @@ def tokenize_chunks_docx(chunks, doc, eng, images):
res = [] res = []
# wrap up as es documents # wrap up as es documents
for ck, image in zip(chunks, images): for ck, image in zip(chunks, images):
if len(ck.strip()) == 0:continue if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck)) logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
d["image"] = image d["image"] = image
@ -288,8 +291,10 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
tokenize(d, rows, eng) tokenize(d, rows, eng)
d["content_with_weight"] = rows d["content_with_weight"] = rows
if img: d["image"] = img if img:
if poss: add_positions(d, poss) d["image"] = img
if poss:
add_positions(d, poss)
res.append(d) res.append(d)
continue continue
de = "; " if eng else " " de = "; " if eng else " "
@ -387,9 +392,9 @@ def title_frequency(bull, sections):
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]): if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size levels[i] = bullets_size
most_level = bullets_size+1 most_level = bullets_size+1
for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1): for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if l <= bullets_size: if level <= bullets_size:
most_level = l most_level = level
break break
return most_level, levels return most_level, levels
@ -504,7 +509,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。"):
def add_chunk(t, pos): def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t) tnum = num_tokens_from_string(t)
if not pos: pos = "" if not pos:
pos = ""
if tnum < 8: if tnum < 8:
pos = "" pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num # Ensure that the length of the merged chunk does not exceed chunk_token_num


@ -121,7 +121,8 @@ class FulltextQueryer:
keywords.append(tt) keywords.append(tt)
twts = self.tw.weights([tt]) twts = self.tw.weights([tt])
syns = self.syn.lookup(tt) syns = self.syn.lookup(tt)
if syns and len(keywords) < 32: keywords.extend(syns) if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False)) logging.debug(json.dumps(twts, ensure_ascii=False))
tms = [] tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1): for tk, w in sorted(twts, key=lambda x: x[1] * -1):
@ -147,7 +148,8 @@ class FulltextQueryer:
tk_syns = self.syn.lookup(tk) tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns] tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]


@ -104,7 +104,6 @@ class RagTokenizer:
return HanziConv.toSimplified(line) return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist): def dfs_(self, chars, s, preTks, tkslist):
MAX_L = 10
res = s res = s
# if s > MAX_L or s>= len(chars): # if s > MAX_L or s>= len(chars):
if s >= len(chars): if s >= len(chars):
@ -184,12 +183,6 @@ class RagTokenizer:
return sorted(res, key=lambda x: x[1], reverse=True) return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks): def merge_(self, tks):
patts = [
(r"[ ]+", " "),
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
]
# for p,s in patts: tks = re.sub(p, s, tks)
# if split chars is part of token # if split chars is part of token
res = [] res = []
tks = re.sub(r"[ ]+", " ", tks).split() tks = re.sub(r"[ ]+", " ", tks).split()
@ -284,7 +277,8 @@ class RagTokenizer:
same = 0 same = 0
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1 same += 1
if same > 0: res.append(" ".join(tks[j: j + same])) if same > 0:
res.append(" ".join(tks[j: j + same]))
_i = i + same _i = i + same
_j = j + same _j = j + same
j = _j + 1 j = _j + 1


@ -62,10 +62,10 @@ class Dealer:
res = {} res = {}
f = open(fnm, "r") f = open(fnm, "r")
while True: while True:
l = f.readline() line = f.readline()
if not l: if not line:
break break
arr = l.replace("\n", "").split("\t") arr = line.replace("\n", "").split("\t")
if len(arr) < 2: if len(arr) < 2:
res[arr[0]] = 0 res[arr[0]] = 0
else: else:


@ -47,7 +47,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __call__(self, chunks, random_state, callback=None): def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))] layers = [(0, len(chunks))]
start, end = 0, len(chunks) start, end = 0, len(chunks)
if len(chunks) <= 1: return if len(chunks) <= 1:
return
chunks = [(s, a) for s, a in chunks if len(a) > 0] chunks = [(s, a) for s, a in chunks if len(a) > 0]
def summarize(ck_idx, lock): def summarize(ck_idx, lock):
@ -66,7 +67,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
logging.debug(f"SUM: {cnt}") logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt]) embds, _ = self._embd_model.encode([cnt])
with lock: with lock:
if not len(embds[0]): return if not len(embds[0]):
return
chunks.append((cnt, embds[0])) chunks.append((cnt, embds[0]))
except Exception as e: except Exception as e:
logging.exception("summarize got exception") logging.exception("summarize got exception")


@ -33,14 +33,16 @@ def collect():
def main(): def main():
locations = collect() locations = collect()
if not locations:return if not locations:
return
logging.info(f"TASKS: {len(locations)}") logging.info(f"TASKS: {len(locations)}")
for kb_id, loc in locations: for kb_id, loc in locations:
try: try:
if REDIS_CONN.is_alive(): if REDIS_CONN.is_alive():
try: try:
key = "{}/{}".format(kb_id, loc) key = "{}/{}".format(kb_id, loc)
if REDIS_CONN.exist(key):continue if REDIS_CONN.exist(key):
continue
file_bin = STORAGE_IMPL.get(kb_id, loc) file_bin = STORAGE_IMPL.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60) REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc)) logging.info("CACHE: {}".format(loc))


@ -23,18 +23,12 @@ import os
from api.utils.log_utils import initRootLogger from api.utils.log_utils import initRootLogger
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)
from datetime import datetime from datetime import datetime
import json import json
import os
import hashlib import hashlib
import copy import copy
import re import re
import sys
import time import time
import threading import threading
from functools import partial from functools import partial
@ -63,6 +57,11 @@ from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)
BATCH_SIZE = 64 BATCH_SIZE = 64
FACTORY = { FACTORY = {
@ -201,7 +200,8 @@ def build_chunks(task, progress_callback):
"doc_id": task["doc_id"], "doc_id": task["doc_id"],
"kb_id": str(task["kb_id"]) "kb_id": str(task["kb_id"])
} }
if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"]) if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"])
el = 0 el = 0
for ck in cks: for ck in cks:
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
@ -342,7 +342,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"docnm_kwd": row["name"], "docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"]) "title_tks": rag_tokenizer.tokenize(row["name"])
} }
if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"]) if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"])
res = [] res = []
tk_count = 0 tk_count = 0
for content, vctr in chunks[original_length:]: for content, vctr in chunks[original_length:]:
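
In `task_executor`, the consumer-name constants and the `initRootLogger(...)` call previously sat between import statements; the hunk moves them below the import block, which is what "module level import not at top of file" (E402) asks for. One visible consequence of the reorder is that the root logger is now configured after, rather than before, the heavier imports. A minimal sketch of the reordering with hypothetical names (`init_logger` below is illustrative only):

```python
# Before (executable code interleaved with imports -> E402 on later imports):
#   import sys
#   CONSUMER_NAME = "task_executor_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
#   init_logger(CONSUMER_NAME)
#   import json            # E402: module level import not at top of file

# After: all imports first, then module-level configuration.
import json
import os
import sys

CONSUMER_NAME = "task_executor_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
print(json.dumps({"consumer": CONSUMER_NAME, "log_levels": LOG_LEVELS}))
```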


@ -41,15 +41,15 @@ def findMaxDt(fnm):
try: try:
with open(fnm, "r") as f: with open(fnm, "r") as f:
while True: while True:
l = f.readline() line = f.readline()
if not l: if not line:
break break
l = l.strip("\n") line = line.strip("\n")
if l == 'nan': if line == 'nan':
continue continue
if l > m: if line > m:
m = l m = line
except Exception as e: except Exception:
pass pass
return m return m
@ -59,15 +59,15 @@ def findMaxTm(fnm):
try: try:
with open(fnm, "r") as f: with open(fnm, "r") as f:
while True: while True:
l = f.readline() line = f.readline()
if not l: if not line:
break break
l = l.strip("\n") line = line.strip("\n")
if l == 'nan': if line == 'nan':
continue continue
if int(l) > m: if int(line) > m:
m = int(l) m = int(line)
except Exception as e: except Exception:
pass pass
return m return m
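
`findMaxDt`/`findMaxTm` keep their broad `except Exception:` guard but drop the `as e` binding, because `e` was never used (an unused-binding finding in the F841 family); the single-letter `l` loop variable also becomes `line`, as elsewhere. Note the exception is still silently swallowed, which this commit does not change. A small runnable sketch of the resulting shape, with a hypothetical function name:

```python
def read_last_marker(path: str) -> str:
    marker = ""
    try:
        with open(path, "r") as f:
            for raw in f:                      # descriptive name instead of 'l'
                line = raw.strip("\n")
                if line and line != "nan" and line > marker:
                    marker = line
    except Exception:                          # no 'as e': the value was unused
        pass
    return marker

print(read_last_marker("/nonexistent/file"))   # '' -- the error is swallowed
```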


@ -32,7 +32,7 @@ class RAGFlowAzureSasBlob(object):
self.conn = None self.conn = None
def health(self): def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary)) return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
def put(self, bucket, fnm, binary): def put(self, bucket, fnm, binary):


@ -36,7 +36,7 @@ class RAGFlowAzureSpnBlob(object):
self.conn = None self.conn = None
def health(self): def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
f = self.conn.create_file(fnm) f = self.conn.create_file(fnm)
f.append_data(binary, offset=0, length=len(binary)) f.append_data(binary, offset=0, length=len(binary))
return f.flush_data(len(binary)) return f.flush_data(len(binary))


@ -132,7 +132,8 @@ class ESConnection(DocStoreConnection):
bqry.filter.append( bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1}))) Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue continue
if not v: continue if not v:
continue
if isinstance(v, list): if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v})) bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int): elif isinstance(v, str) or isinstance(v, int):


@ -1,10 +1,5 @@
from beartype.claw import beartype_this_package
beartype_this_package() # <-- raise exceptions in your code
import importlib.metadata import importlib.metadata
__version__ = importlib.metadata.version("ragflow_sdk")
from .ragflow import RAGFlow from .ragflow import RAGFlow
from .modules.dataset import DataSet from .modules.dataset import DataSet
from .modules.chat import Chat from .modules.chat import Chat
@ -12,3 +7,15 @@ from .modules.session import Session
from .modules.document import Document from .modules.document import Document
from .modules.chunk import Chunk from .modules.chunk import Chunk
from .modules.agent import Agent from .modules.agent import Agent
__version__ = importlib.metadata.version("ragflow_sdk")
__all__ = [
"RAGFlow",
"DataSet",
"Chat",
"Session",
"Document",
"Chunk",
"Agent"
]
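
The SDK's `__init__.py` now declares `__all__` for its re-exports and moves `__version__` below the imports. In an `__init__.py`, names imported only to be re-exported otherwise look unused to F401; listing them in `__all__` documents the public surface and tells both the linter and `from package import *` users what is exported. A minimal sketch of the idea, written as a plain script with stdlib stand-ins for the SDK classes:

```python
# Hypothetical package __init__, sketched for illustration.
from json import dumps, loads  # stand-ins for re-exported SDK classes

# __all__ declares the public API and marks the imports above as
# deliberate re-exports rather than unused imports (F401).
__all__ = ["dumps", "loads"]

print(loads(dumps({"ok": True})))
```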


@ -29,7 +29,7 @@ class Session(Base):
raise Exception(json_data["message"]) raise Exception(json_data["message"])
if line.startswith("data:"): if line.startswith("data:"):
json_data = json.loads(line[5:]) json_data = json.loads(line[5:])
if json_data["data"] != True: if not json_data["data"]:
answer = json_data["data"]["answer"] answer = json_data["data"]["answer"]
reference = json_data["data"]["reference"] reference = json_data["data"]["reference"]
temp_dict = { temp_dict = {


@ -1,5 +1,3 @@
import string
import random
import os import os
import pytest import pytest
import requests import requests


@ -39,7 +39,6 @@ def update_dataset(auth, json_req):
def upload_file(auth, dataset_id, path): def upload_file(auth, dataset_id, path):
authorization = {"Authorization": auth} authorization = {"Authorization": auth}
url = f"{HOST_ADDRESS}/v1/document/upload" url = f"{HOST_ADDRESS}/v1/document/upload"
base_name = os.path.basename(path)
json_req = { json_req = {
"kb_id": dataset_id, "kb_id": dataset_id,
} }


@ -1,3 +1,3 @@
def test_get_email(get_email): def test_get_email(get_email):
print(f"\nEmail account:",flush=True) print("\nEmail account:",flush=True)
print(f"{get_email}\n",flush=True) print(f"{get_email}\n",flush=True)


@ -13,14 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, upload_file, DATASET_NAME_LIMIT from common import create_dataset, list_dataset, rm_dataset, upload_file
from common import list_document, get_docs_info, parse_docs from common import list_document, get_docs_info, parse_docs
from time import sleep from time import sleep
from timeit import default_timer as timer from timeit import default_timer as timer
import re
import pytest
import random
import string
def test_parse_txt_document(get_auth): def test_parse_txt_document(get_auth):


@ -1,6 +1,5 @@
from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT from common import create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
import re import re
import pytest
import random import random
import string import string
@ -33,8 +32,6 @@ def test_dataset(get_auth):
def test_dataset_1k_dataset(get_auth): def test_dataset_1k_dataset(get_auth):
# create dataset # create dataset
authorization = {"Authorization": get_auth}
url = f"{HOST_ADDRESS}/v1/kb/create"
for i in range(1000): for i in range(1000):
res = create_dataset(get_auth, f"test_create_dataset_{i}") res = create_dataset(get_auth, f"test_create_dataset_{i}")
assert res.get("code") == 0, f"{res.get('message')}" assert res.get("code") == 0, f"{res.get('message')}"
@ -76,7 +73,7 @@ def test_duplicated_name_dataset(get_auth):
dataset_id = item.get("id") dataset_id = item.get("id")
dataset_list.append(dataset_id) dataset_list.append(dataset_id)
match = re.match(pattern, dataset_name) match = re.match(pattern, dataset_name)
assert match != None assert match is not None
for dataset_id in dataset_list: for dataset_id in dataset_list:
res = rm_dataset(get_auth, dataset_id) res = rm_dataset(get_auth, dataset_id)
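
Two comparison fixes appear in the SDK and tests: `match != None` becomes `match is not None` (the idiomatic None check, E711), and `json_data["data"] != True` becomes `if not json_data["data"]:`. The second rewrite deserves a caution: for non-boolean values the two are not equivalent, since a non-empty dict compares unequal to `True` yet is truthy, so truthiness rewrites of `== True`/`!= True` should only be applied when the value is known to be a bool. A short demonstration (hypothetical data):

```python
import re

match = re.match(r"test_\d+", "test_42")
assert match is not None          # idiomatic None check (E711)

data = {"answer": "hi"}           # a non-boolean payload
print(data != True)               # True  -> the original branch would run
print(not data)                   # False -> the rewritten branch would not
```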


@ -1,3 +1,3 @@
def test_get_email(get_email): def test_get_email(get_email):
print(f"\nEmail account:",flush=True) print("\nEmail account:",flush=True)
print(f"{get_email}\n",flush=True) print(f"{get_email}\n",flush=True)


@ -1,4 +1,4 @@
from ragflow_sdk import RAGFlow,Agent from ragflow_sdk import RAGFlow
from common import HOST_ADDRESS from common import HOST_ADDRESS
import pytest import pytest