Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

Fix errors and style issues reported by Ruff across the repo: unused imports and variables, compound statements on one line, ambiguous single-letter names, bare `except` clauses, `type()` comparisons, and `== True` checks.

### Type of change

- [x] Refactoring
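
Most of the diff below applies the same handful of Ruff rules over and over: compound statements split onto separate lines (E701), ambiguous single-letter names like `l` renamed (E741), unused `except ... as e` bindings and bare excepts cleaned up (F841/E722), `type(x) == ...` comparisons replaced with `isinstance()` (E721), `== True` comparisons dropped (E712), and unused imports removed or declared as re-exports (F401). The sketch below is illustrative only: `parse_floats` and its data are made up, and the exact rule set enabled for this repo is an assumption.

```python
# Illustrative sketch of the idioms this PR converges on, with the
# corresponding pycodestyle/pyflakes rule codes (as implemented by Ruff).

def parse_floats(items):  # hypothetical helper, not part of the codebase
    results = []
    for line in items:                 # E741: avoid the ambiguous name `l`
        if not line:                   # E701: no `if not line: continue` on one line
            continue
        if not isinstance(line, str):  # E721: isinstance() instead of type() comparison
            line = str(line)
        try:
            results.append(float(line))
        except Exception:              # E722/F841: no bare except, no unused `as e`
            results.append(0.0)
    return results


if __name__ == "__main__":
    print(parse_floats(["1", "", "2.5", "oops"]))  # -> [1.0, 2.5, 0.0]
```

Listing re-exported names in `__all__` (as the `agent/component` and `deepdoc/parser` package `__init__` hunks do) is the standard way to mark intentional re-exports for F401.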
Zhichang Yu 2024-12-08 14:21:12 +08:00 committed by GitHub
parent e267a026f3
commit 0d68a6cd1b
97 changed files with 2558 additions and 1976 deletions


@ -133,7 +133,8 @@ class Canvas(ABC):
"components": {}
}
for k in self.dsl.keys():
if k in ["components"]:continue
if k in ["components"]:
continue
dsl[k] = deepcopy(self.dsl[k])
for k, cpn in self.components.items():
@ -158,7 +159,8 @@ class Canvas(ABC):
def get_compnent_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
if cid == n["id"]: return n["data"]["name"]
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, **kwargs):
@ -173,7 +175,8 @@ class Canvas(ABC):
if kwargs.get("stream"):
for an in ans():
yield an
else: yield ans
else:
yield ans
return
if not self.path:
@ -188,7 +191,8 @@ class Canvas(ABC):
def prepare2run(cpns):
nonlocal ran, ans
for c in cpns:
if self.path[-1] and c == self.path[-1][-1]: continue
if self.path[-1] and c == self.path[-1][-1]:
continue
cpn = self.components[c]["obj"]
if cpn.component_name == "Answer":
self.answer.append(c)
@ -197,7 +201,8 @@ class Canvas(ABC):
if c not in without_dependent_checking:
cpids = cpn.get_dependent_components()
if any([cc not in self.path[-1] for cc in cpids]):
if c not in waiting: waiting.append(c)
if c not in waiting:
waiting.append(c)
continue
yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
ans = cpn.run(self.history, **kwargs)
@ -211,10 +216,12 @@ class Canvas(ABC):
logging.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)
if not cpn["downstream"]: break
if not cpn["downstream"]:
break
loop = self._find_loop()
if loop: raise OverflowError(f"Too much loops: {loop}")
if loop:
raise OverflowError(f"Too much loops: {loop}")
if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
switch_out = cpn["obj"].output()[1].iloc[0, 0]
@ -283,19 +290,22 @@ class Canvas(ABC):
def _find_loop(self, max_loops=6):
path = self.path[-1][::-1]
if len(path) < 2: return False
if len(path) < 2:
return False
for i in range(len(path)):
if path[i].lower().find("answer") >= 0:
path = path[:i]
break
if len(path) < 2: return False
if len(path) < 2:
return False
for l in range(2, len(path) // 2):
pat = ",".join(path[0:l])
for loc in range(2, len(path) // 2):
pat = ",".join(path[0:loc])
path_str = ",".join(path)
if len(pat) >= len(path_str): return False
if len(pat) >= len(path_str):
return False
loop = max_loops
while path_str.find(pat) == 0 and loop >= 0:
loop -= 1
@ -303,7 +313,7 @@ class Canvas(ABC):
return False
path_str = path_str[len(pat)+1:]
if loop < 0:
pat = " => ".join([p.split(":")[0] for p in path[0:l]])
pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
return pat + " => " + pat
return False


@ -39,3 +39,73 @@ def component_class(class_name):
m = importlib.import_module("agent.component")
c = getattr(m, class_name)
return c
__all__ = [
"Begin",
"BeginParam",
"Generate",
"GenerateParam",
"Retrieval",
"RetrievalParam",
"Answer",
"AnswerParam",
"Categorize",
"CategorizeParam",
"Switch",
"SwitchParam",
"Relevant",
"RelevantParam",
"Message",
"MessageParam",
"RewriteQuestion",
"RewriteQuestionParam",
"KeywordExtract",
"KeywordExtractParam",
"Concentrator",
"ConcentratorParam",
"Baidu",
"BaiduParam",
"DuckDuckGo",
"DuckDuckGoParam",
"Wikipedia",
"WikipediaParam",
"PubMed",
"PubMedParam",
"ArXiv",
"ArXivParam",
"Google",
"GoogleParam",
"Bing",
"BingParam",
"GoogleScholar",
"GoogleScholarParam",
"DeepL",
"DeepLParam",
"GitHub",
"GitHubParam",
"BaiduFanyi",
"BaiduFanyiParam",
"QWeather",
"QWeatherParam",
"ExeSQL",
"ExeSQLParam",
"YahooFinance",
"YahooFinanceParam",
"WenCai",
"WenCaiParam",
"Jin10",
"Jin10Param",
"TuShare",
"TuShareParam",
"AkShare",
"AkShareParam",
"Crawler",
"CrawlerParam",
"Invoke",
"InvokeParam",
"Template",
"TemplateParam",
"Email",
"EmailParam",
"component_class"
]


@ -428,7 +428,8 @@ class ComponentBase(ABC):
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o]
if not isinstance(o, list):
o = [o]
o = pd.DataFrame(o)
if allow_partial or not isinstance(o, partial):
@ -440,7 +441,8 @@ class ComponentBase(ABC):
for oo in o():
if not isinstance(oo, pd.DataFrame):
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
else: outs = oo
else:
outs = oo
return self._param.output_var_name, outs
def reset(self):
@ -482,13 +484,15 @@ class ComponentBase(ABC):
outs.append(pd.DataFrame([{"content": q["value"]}]))
if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
if "content" in df:
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
return df
upstream_outs = []
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.get_component_name(u) in ["switch", "concentrator"]:
continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
@ -532,7 +536,8 @@ class ComponentBase(ABC):
reversed_cpnts.extend(self._canvas.path[-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "answer"]: continue
if self.get_component_name(u) in ["switch", "answer"]:
continue
return self._canvas.get_component(u)["obj"].output()[1]
@staticmethod


@ -34,15 +34,18 @@ class CategorizeParam(GenerateParam):
super().check()
self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
if not k: raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
if not k:
raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"):
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
def get_prompt(self):
cate_lines = []
for c, desc in self.category_description.items():
for l in desc.get("examples", "").split("\n"):
if not l: continue
cate_lines.append("Question: {}\tCategory: {}".format(l, c))
for line in desc.get("examples", "").split("\n"):
if not line:
continue
cate_lines.append("Question: {}\tCategory: {}".format(line, c))
descriptions = []
for c, desc in self.category_description.items():
if desc.get("description"):


@ -14,7 +14,6 @@
# limitations under the License.
#
from abc import ABC
import re
from agent.component.base import ComponentBase, ComponentParamBase
import deepl


@ -46,8 +46,10 @@ class ExeSQLParam(ComponentParamBase):
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
if self.host == "ragflow-mysql":
raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow":
raise ValueError("The host is not accessible.")
class ExeSQL(ComponentBase, ABC):


@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase):
def gen_conf(self):
conf = {}
if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
if self.temperature > 0: conf["temperature"] = self.temperature
if self.top_p > 0: conf["top_p"] = self.top_p
if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
return conf
@ -83,7 +88,8 @@ class Generate(ComponentBase):
recall_docs = []
for i in idx:
did = retrieval_res.loc[int(i), "doc_id"]
if did in doc_ids: continue
if did in doc_ids:
continue
doc_ids.add(did)
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
@ -108,7 +114,8 @@ class Generate(ComponentBase):
retrieval_res = []
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if not para.get("component_id"):
continue
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
@ -142,7 +149,8 @@ class Generate(ComponentBase):
if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([])
else:
retrieval_res = pd.DataFrame([])
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
@ -164,9 +172,11 @@ class Generate(ComponentBase):
return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
@ -185,9 +195,11 @@ class Generate(ComponentBase):
return
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
answer = ""
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []}


@ -95,7 +95,8 @@ class RewriteQuestion(Generate, ABC):
hist = self._canvas.get_history(4)
conv = []
for m in hist:
if m["role"] not in ["user", "assistant"]: continue
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)


@ -41,7 +41,8 @@ class SwitchParam(ComponentParamBase):
def check(self):
self.check_empty(self.conditions, "[Switch] conditions")
for cond in self.conditions:
if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
if not cond["to"]:
raise ValueError("[Switch] 'To' can not be empty!")
class Switch(ComponentBase, ABC):
@ -51,7 +52,8 @@ class Switch(ComponentBase, ABC):
res = []
for cond in self._param.conditions:
for item in cond["items"]:
if not item["cpn_id"]: continue
if not item["cpn_id"]:
continue
if item["cpn_id"].find("begin") >= 0:
continue
cid = item["cpn_id"].split("@")[0]
@ -63,7 +65,8 @@ class Switch(ComponentBase, ABC):
for cond in self._param.conditions:
res = []
for item in cond["items"]:
if not item["cpn_id"]:continue
if not item["cpn_id"]:
continue
cid = item["cpn_id"].split("@")[0]
if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
@ -107,22 +110,22 @@ class Switch(ComponentBase, ABC):
elif operator == ">":
try:
return True if float(input) > float(value) else False
except Exception as e:
except Exception:
return True if input > value else False
elif operator == "<":
try:
return True if float(input) < float(value) else False
except Exception as e:
except Exception:
return True if input < value else False
elif operator == "":
try:
return True if float(input) >= float(value) else False
except Exception as e:
except Exception:
return True if input >= value else False
elif operator == "":
try:
return True if float(input) <= float(value) else False
except Exception as e:
except Exception:
return True if input <= value else False
raise ValueError('Not supported operator' + operator)


@ -47,7 +47,8 @@ class Template(ComponentBase):
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if not para.get("component_id"):
continue
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")


@ -43,6 +43,7 @@ if __name__ == '__main__':
else:
print(ans["content"])
if DEBUG: print(canvas.path)
if DEBUG:
print(canvas.path)
question = input("\n==================== User =====================\n> ")
canvas.add_user_input(question)


@ -142,7 +142,6 @@ def set_conversation():
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if objs[0].source == "agent":
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
@ -188,7 +187,8 @@ def completion():
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req: req["quote"] = False
if "quote" not in req:
req["quote"] = False
msg = []
for m in req["messages"]:
@ -197,7 +197,8 @@ def completion():
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):
@ -674,11 +675,13 @@ def completion_faq():
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req: req["quote"] = True
if "quote" not in req:
req["quote"] = True
msg = []
msg.append({"role": "user", "content": req["word"]})
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):


@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import json
import traceback
from functools import partial
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
@ -60,7 +58,8 @@ def rm():
def save():
req = request.json
req["user_id"] = current_user.id
if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
if "id" not in req:
@ -153,7 +152,8 @@ def run():
return resp
for answer in canvas.run(stream=False):
if answer.get("running_status"): continue
if answer.get("running_status"):
continue
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):


@ -237,7 +237,8 @@ def create():
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank: d["pagerank_fea"] = kb.pagerank
if kb.pagerank:
d["pagerank_fea"] = kb.pagerank
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)


@ -281,10 +281,12 @@ def thumbup():
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
if up_down:
msg["thumbup"] = True
if "feedback" in msg: del msg["feedback"]
if "feedback" in msg:
del msg["feedback"]
else:
msg["thumbup"] = False
if feedback: msg["feedback"] = feedback
if feedback:
msg["feedback"] = feedback
break
ConversationService.update_by_id(conv["id"], conv)


@ -37,10 +37,12 @@ def set_dialog():
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
if not rerank_id: req["rerank_id"] = ""
if not rerank_id:
req["rerank_id"] = ""
similarity_threshold = req.get("similarity_threshold", 0.1)
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
if vector_similarity_weight is None: vector_similarity_weight = 0.3
if vector_similarity_weight is None:
vector_similarity_weight = 0.3
llm_setting = req.get("llm_setting", {})
default_prompt = {
"system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。


@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
import json
import os.path
import pathlib
import re
@ -90,7 +89,8 @@ def web_crawl():
raise LookupError("Can't find this knowledgebase!")
blob = html2pdf(url)
if not blob: return server_error_response(ValueError("Download failure."))
if not blob:
return server_error_response(ValueError("Download failure."))
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@ -290,7 +290,8 @@ def change_status():
def rm():
req = request.json
doc_ids = req["doc_id"]
if isinstance(doc_ids, str): doc_ids = [doc_ids]
if isinstance(doc_ids, str):
doc_ids = [doc_ids]
for doc_id in doc_ids:
if not DocumentService.accessible4deletion(doc_id, current_user.id):


@ -351,8 +351,10 @@ def list_app():
llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
for o in objs:
if not o.api_key: continue
if o.llm_name + "@" + o.llm_factory in llm_set: continue
if not o.api_key:
continue
if o.llm_name + "@" + o.llm_factory in llm_set:
continue
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
res = {}


@ -14,7 +14,7 @@
# limitations under the License.
#
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
from flask import request


@ -41,7 +41,6 @@ from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search
from rag.utils import rmSpace
from rag.utils.storage_factory import STORAGE_IMPL
import os
MAXIMUM_OF_UPLOADING_FILES = 256
@ -976,12 +975,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
if not req.get("content"):
return get_error_data_result(message="`content` is required")
if "important_keywords" in req:
if type(req["important_keywords"]) != list:
if not isinstance(req["important_keywords"], list):
return get_error_data_result(
"`important_keywords` is required to be a list"
)
if "questions" in req:
if type(req["questions"]) != list:
if not isinstance(req["questions"], list):
return get_error_data_result(
"`questions` is required to be a list"
)


@ -143,8 +143,10 @@ def completion(tenant_id, chat_id):
}
conv.message.append(question)
for m in conv.message:
if m["role"] == "system": continue
if m["role"] == "assistant" and not msg: continue
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
message_id = msg[-1].get("id")
e, dia = DialogService.get_by_id(conv.dialog_id)
@ -267,7 +269,8 @@ def agent_completion(tenant_id, agent_id):
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
stream = req.get("stream", True)
@ -361,7 +364,8 @@ def agent_completion(tenant_id, agent_id):
return resp
for answer in canvas.run(stream=False):
if answer.get("running_status"): continue
if answer.get("running_status"):
continue
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):


@ -330,7 +330,7 @@ def user_info_from_github(access_token):
headers=headers,
).json()
user_info["email"] = next(
(email for email in email_info if email["primary"] == True), None
(email for email in email_info if email["primary"]), None
)["email"]
return user_info


@ -130,7 +130,7 @@ def is_continuous_field(cls: typing.Type) -> bool:
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
elif p != Field and p != object:
elif p is not Field and p is not object:
if is_continuous_field(p):
return True
else:


@ -170,7 +170,7 @@ def add_graph_templates():
cnvs = json.load(open(os.path.join(dir, fnm), "r"))
try:
CanvasTemplateService.save(**cnvs)
except:
except Exception:
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
except Exception:
logging.exception("Add graph templates error: ")


@ -15,13 +15,14 @@
#
import pathlib
import re
from .user_service import UserService
from .user_service import UserService as UserService
def duplicate_name(query_func, **kwargs):
fnm = kwargs["name"]
objs = query_func(**kwargs)
if not objs: return fnm
if not objs:
return fnm
ext = pathlib.Path(fnm).suffix #.jpg
nm = re.sub(r"%s$"%ext, "", fnm)
r = re.search(r"\(([0-9]+)\)$", nm)
@ -31,8 +32,8 @@ def duplicate_name(query_func, **kwargs):
nm = re.sub(r"\([0-9]+\)$", "", nm)
c += 1
nm = f"{nm}({c})"
if ext: nm += f"{ext}"
if ext:
nm += f"{ext}"
kwargs["name"] = nm
return duplicate_name(query_func, **kwargs)


@ -64,7 +64,8 @@ class API4ConversationService(CommonService):
@classmethod
@DB.connection_context()
def stats(cls, tenant_id, from_date, to_date, source=None):
if len(to_date) == 10: to_date += " 23:59:59"
if len(to_date) == 10:
to_date += " 23:59:59"
return cls.model.select(
cls.model.create_date.truncate("day").alias("dt"),
peewee.fn.COUNT(


@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
from api.db.db_models import DB, CanvasTemplate, UserCanvas
from api.db.services.common_service import CommonService


@ -115,7 +115,7 @@ class CommonService:
try:
obj = cls.model.query(id=pid)[0]
return True, obj
except Exception as e:
except Exception:
return False, None
@classmethod


@ -106,15 +106,15 @@ def message_fit_in(msg, max_length=4000):
return c, msg
ll = num_tokens_from_string(msg_[0]["content"])
l = num_tokens_from_string(msg_[-1]["content"])
if ll / (ll + l) > 0.8:
ll2 = num_tokens_from_string(msg_[-1]["content"])
if ll / (ll + ll2) > 0.8:
m = msg_[0]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - l])
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
msg[0]["content"] = m
return max_length, msg
m = msg_[1]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - l])
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
msg[1]["content"] = m
return max_length, msg
@ -257,7 +257,8 @@ def chat(dialog, messages, stream=True, **kwargs):
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs
refs = deepcopy(kbinfos)
@ -433,13 +434,15 @@ def relevant(tenant_id, llm_id, question, contents: list):
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
No other words needed except 'yes' or 'no'.
"""
if not contents:return False
if not contents:
return False
contents = "Documents: \n" + " - ".join(contents)
contents = f"Question: {question}\n" + contents
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
if ans.lower().find("yes") >= 0: return True
if ans.lower().find("yes") >= 0:
return True
return False
@ -481,8 +484,10 @@ Requirements:
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): kwd = kwd[0]
if kwd.find("**ERROR**") >=0: return ""
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >=0:
return ""
return kwd
@ -508,8 +513,10 @@ Requirements:
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): kwd = kwd[0]
if kwd.find("**ERROR**") >= 0: return ""
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
@ -520,7 +527,8 @@ def full_question(tenant_id, llm_id, messages):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
conv = []
for m in messages:
if m["role"] not in ["user", "assistant"]: continue
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)
today = datetime.date.today().isoformat()
@ -581,7 +589,8 @@ Output: What's the weather in Rochester on {tomorrow}?
def tts(tts_mdl, text):
if not tts_mdl or not text: return
if not tts_mdl or not text:
return
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
@ -641,7 +650,8 @@ def ask(question, kb_ids, tenant_id):
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs
refs = deepcopy(kbinfos)
for c in refs["chunks"]:


@ -532,7 +532,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
try:
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
ensure_ascii=False, indent=2)
if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,


@ -20,7 +20,7 @@ from api.db.db_models import DB
from api.db.db_models import File, File2Document
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp, datetime_format, get_uuid
from api.utils import current_timestamp, datetime_format
class File2DocumentService(CommonService):
@ -63,7 +63,7 @@ class File2DocumentService(CommonService):
def update_by_file_id(cls, file_id, obj):
obj["update_time"] = current_timestamp()
obj["update_date"] = datetime_format(datetime.now())
num = cls.model.update(obj).where(cls.model.id == file_id).execute()
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
e, obj = cls.get_by_id(cls.model.id)
return obj


@ -85,7 +85,8 @@ class FileService(CommonService):
.join(Document, on=(File2Document.document_id == Document.id))
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
.where(cls.model.id == file_id))
if not kbs: return []
if not kbs:
return []
kbs_info_list = []
for kb in list(kbs.dicts()):
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
@ -304,7 +305,8 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
for _ in File2DocumentService.get_by_document_id(doc["id"]): return
for _ in File2DocumentService.get_by_document_id(doc["id"]):
return
file = {
"id": get_uuid(),
"parent_id": kb_folder_id,


@ -107,7 +107,8 @@ class TenantLLMService(CommonService):
model_config = cls.get_api_key(tenant_id, mdlnm)
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if model_config: model_config = model_config.to_dict()
if model_config:
model_config = model_config.to_dict()
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)


@ -57,28 +57,33 @@ class TaskService(CommonService):
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == task_id)
)
docs = list(docs.dicts())
if not docs: return None
if not docs:
return None
msg = "\nTask has been received."
prog = random.random() / 10.
prog = random.random() / 10.0
if docs[0]["retry_count"] >= 3:
msg = "\nERROR: Task is abandoned after 3 times attempts."
prog = -1
cls.model.update(progress_msg=cls.model.progress_msg + msg,
progress=prog,
retry_count=docs[0]["retry_count"]+1
).where(
cls.model.id == docs[0]["id"]).execute()
cls.model.update(
progress_msg=cls.model.progress_msg + msg,
progress=prog,
retry_count=docs[0]["retry_count"] + 1,
).where(cls.model.id == docs[0]["id"]).execute()
if docs[0]["retry_count"] >= 3: return None
if docs[0]["retry_count"] >= 3:
return None
return docs[0]
@ -86,21 +91,44 @@ class TaskService(CommonService):
@DB.connection_context()
def get_ongoing_doc_name(cls):
with DB.lock("get_task", -1):
docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
.join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
docs = (
cls.model.select(
*[Document.id, Document.kb_id, Document.location, File.parent_id]
)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(
File2Document,
on=(File2Document.document_id == Document.id),
join_type=JOIN.LEFT_OUTER,
)
.join(
File,
on=(File2Document.file_id == File.id),
join_type=JOIN.LEFT_OUTER,
)
.where(
Document.status == StatusEnum.VALID.value,
Document.run == TaskStatus.RUNNING.value,
~(Document.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.create_time >= current_timestamp() - 1000 * 600
cls.model.create_time >= current_timestamp() - 1000 * 600,
)
)
docs = list(docs.dicts())
if not docs: return []
if not docs:
return []
return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
return list(
set(
[
(
d["parent_id"] if d["parent_id"] else d["kb_id"],
d["location"],
)
for d in docs
]
)
)
@classmethod
@DB.connection_context()
@ -118,28 +146,30 @@ class TaskService(CommonService):
def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
cls.model.id == id).execute()
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id).execute()
cls.model.id == id
).execute()
return
with DB.lock("update_progress", -1):
if info["progress_msg"]:
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
cls.model.id == id).execute()
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id).execute()
cls.model.id == id
).execute()
def queue_tasks(doc: dict, bucket: str, name: str):
def new_task():
return {
"id": get_uuid(),
"doc_id": doc["id"]
}
return {"id": get_uuid(), "doc_id": doc["id"]}
tsks = []
if doc["type"] == FileType.PDF.value:
@ -150,8 +180,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
if doc["parser_id"] == "paper":
page_size = doc["parser_config"].get("task_page_size", 22)
if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
page_size = 10 ** 9
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
page_size = 10**9
page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
for s, e in page_ranges:
s -= 1
s = max(0, s)
@ -177,4 +207,6 @@ def queue_tasks(doc: dict, bucket: str, name: str):
DocumentService.begin2parse(doc["id"])
for t in tsks:
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
assert REDIS_CONN.queue_product(
SVR_QUEUE_NAME, message=t
), "Can't access Redis. Please check the Redis' status."


@ -22,7 +22,7 @@ from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.utils import get_uuid, current_timestamp, datetime_format
from api.db import StatusEnum


@ -21,10 +21,7 @@
import logging
import os
from api.utils.log_utils import initRootLogger
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger("ragflow_server", LOG_LEVELS)
import os
import signal
import sys
import time
@ -44,6 +41,9 @@ from api.versions import get_ragflow_version
from api.utils import show_configs
from rag.settings import print_rag_settings
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger("ragflow_server", LOG_LEVELS)
def update_progress():
while True:


@ -36,7 +36,6 @@ from werkzeug.http import HTTP_STATUS_CODES
from api.db.db_models import APIToken
from api import settings
from api import settings
from api.utils import CustomJSONEncoder, get_uuid
from api.utils import json_dumps
from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC


@ -45,5 +45,5 @@ try:
pool = Pool(processes=1)
thread = pool.apply_async(download_nltk_data)
binary = thread.get(timeout=60)
except Exception as e:
except Exception:
print('\x1b[6;37;41m WARNING \x1b[0m' + "Downloading NLTK data failure.", flush=True)


@ -18,4 +18,16 @@ from .ppt_parser import RAGFlowPptParser as PptParser
from .html_parser import RAGFlowHtmlParser as HtmlParser
from .json_parser import RAGFlowJsonParser as JsonParser
from .markdown_parser import RAGFlowMarkdownParser as MarkdownParser
from .txt_parser import RAGFlowTxtParser as TxtParser
from .txt_parser import RAGFlowTxtParser as TxtParser
__all__ = [
"PdfParser",
"PlainParser",
"DocxParser",
"ExcelParser",
"PptParser",
"HtmlParser",
"JsonParser",
"MarkdownParser",
"TxtParser",
]


@ -29,7 +29,8 @@ class RAGFlowExcelParser:
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
if not rows: continue
if not rows:
continue
tb_rows_0 = "<tr>"
for t in list(rows[0]):
@ -40,7 +41,9 @@ class RAGFlowExcelParser:
tb = ""
tb += f"<table><caption>{sheetname}</caption>"
tb += tb_rows_0
for r in list(rows[1 + chunk_i * chunk_rows:1 + (chunk_i + 1) * chunk_rows]):
for r in list(
rows[1 + chunk_i * chunk_rows : 1 + (chunk_i + 1) * chunk_rows]
):
tb += "<tr>"
for i, c in enumerate(r):
if c.value is None:
@ -62,20 +65,21 @@ class RAGFlowExcelParser:
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
if not rows:continue
if not rows:
continue
ti = list(rows[0])
for r in list(rows[1:]):
l = []
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
l.append(t)
l = "; ".join(l)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
l += " ——" + sheetname
res.append(l)
line += " ——" + sheetname
res.append(line)
return res
@staticmethod


@ -36,7 +36,7 @@ class RAGFlowHtmlParser:
@classmethod
def parser_txt(cls, txt):
if type(txt) != str:
if not isinstance(txt, str):
raise TypeError("txt type should be str!")
html_doc = readability.Document(txt)
title = html_doc.title()


@ -22,7 +22,7 @@ class RAGFlowJsonParser:
txt = binary.decode(encoding, errors="ignore")
json_data = json.loads(txt)
chunks = self.split_json(json_data, True)
sections = [json.dumps(l, ensure_ascii=False) for l in chunks if l]
sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
return sections
@staticmethod


@ -752,7 +752,7 @@ class RAGFlowPdfParser:
"x1": np.max([b["x1"] for b in bxs]),
"bottom": np.max([b["bottom"] for b in bxs]) - ht
}
louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
louts = [layout for layout in self.page_layout[pn] if layout["type"] == ltype]
ii = Recognizer.find_overlapped(b, louts, naive=True)
if ii is not None:
b = louts[ii]
@ -763,7 +763,8 @@ class RAGFlowPdfParser:
"layoutno", "")))
left, top, right, bott = b["x0"], b["top"], b["x1"], b["bottom"]
if right < left: right = left + 1
if right < left:
right = left + 1
poss.append((pn + self.page_from, left, right, top, bott))
return self.page_images[pn] \
.crop((left * ZM, top * ZM,
@ -845,7 +846,8 @@ class RAGFlowPdfParser:
top = bx["top"] - self.page_cum_height[pn[0] - 1]
bott = bx["bottom"] - self.page_cum_height[pn[0] - 1]
page_images_cnt = len(self.page_images)
if pn[-1] - 1 >= page_images_cnt: return ""
if pn[-1] - 1 >= page_images_cnt:
return ""
while bott * ZM > self.page_images[pn[-1] - 1].size[1]:
bott -= self.page_images[pn[-1] - 1].size[1] / ZM
pn.append(pn[-1] + 1)
@ -889,7 +891,6 @@ class RAGFlowPdfParser:
nonlocal mh, pw, lines, widths
lines.append(line)
widths.append(width(line))
width_mean = np.mean(widths)
mmj = self.proj_match(
line["text"]) or line.get(
"layout_type",
@ -994,7 +995,7 @@ class RAGFlowPdfParser:
else:
self.is_english = False
st = timer()
# st = timer()
for i, img in enumerate(self.page_images_x2):
chars = self.page_chars[i] if not self.is_english else []
self.mean_height.append(
@ -1028,8 +1029,8 @@ class RAGFlowPdfParser:
self.page_cum_height = np.cumsum(self.page_cum_height)
assert len(self.page_cum_height) == len(self.page_images) + 1
if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
page_to, callback)
if len(self.boxes) == 0 and zoomin < 9:
self.__images__(fnm, zoomin * 3, page_from, page_to, callback)
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.__images__(fnm, zoomin)
@ -1168,7 +1169,7 @@ class PlainParser(object):
if not self.outlines:
logging.warning("Miss outlines")
return [(l, "") for l in lines], []
return [(line, "") for line in lines], []
def crop(self, ck, need_position):
raise NotImplementedError


@ -15,21 +15,42 @@ import datetime
def refactor(cv):
for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]:
if n in cv and cv[n] is not None: del cv[n]
for n in [
"raw_txt",
"parser_name",
"inference",
"ori_text",
"use_time",
"time_stat",
]:
if n in cv and cv[n] is not None:
del cv[n]
cv["is_deleted"] = 0
if "basic" not in cv: cv["basic"] = {}
if cv["basic"].get("photo2"): del cv["basic"]["photo2"]
if "basic" not in cv:
cv["basic"] = {}
if cv["basic"].get("photo2"):
del cv["basic"]["photo2"]
for n in ["education", "work", "certificate", "project", "language", "skill", "training"]:
if n not in cv or cv[n] is None: continue
if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()]
if type(cv[n]) != type([]):
for n in [
"education",
"work",
"certificate",
"project",
"language",
"skill",
"training",
]:
if n not in cv or cv[n] is None:
continue
if isinstance(cv[n], dict):
cv[n] = [v for _, v in cv[n].items()]
if not isinstance(cv[n], list):
del cv[n]
continue
vv = []
for v in cv[n]:
if "external" in v and v["external"] is not None: del v["external"]
if "external" in v and v["external"] is not None:
del v["external"]
vv.append(v)
cv[n] = {str(i): vv[i] for i in range(len(vv))}
@ -42,24 +63,44 @@ def refactor(cv):
cv["basic"][t] = cv["basic"][n]
del cv["basic"][n]
work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", ""))
edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", ""))
work = sorted(
[v for _, v in cv.get("work", {}).items()],
key=lambda x: x.get("start_time", ""),
)
edu = sorted(
[v for _, v in cv.get("education", {}).items()],
key=lambda x: x.get("start_time", ""),
)
if work:
cv["basic"]["work_start_time"] = work[0].get("start_time", "")
cv["basic"]["management_experience"] = 'Y' if any(
[w.get("management_experience", '') == 'Y' for w in work]) else 'N'
cv["basic"]["management_experience"] = (
"Y"
if any([w.get("management_experience", "") == "Y" for w in work])
else "N"
)
cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")
for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities",
"corporation_type", "scale", "corporation_name"]:
for n in [
"annual_salary_from",
"annual_salary_to",
"industry_name",
"position_name",
"responsibilities",
"corporation_type",
"scale",
"corporation_name",
]:
cv["basic"][n] = work[-1].get(n, "")
if edu:
for n in ["school_name", "discipline_name"]:
if n in edu[-1]: cv["basic"][n] = edu[-1][n]
if n in edu[-1]:
cv["basic"][n] = edu[-1][n]
cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if "contact" not in cv: cv["contact"] = {}
if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "")
return cv
if "contact" not in cv:
cv["contact"] = {}
if not cv["contact"].get("name"):
cv["contact"]["name"] = cv["basic"].get("name", "")
return cv


@ -21,13 +21,18 @@ from . import regions
current_file_path = os.path.dirname(os.path.abspath(__file__))
GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0)
GOODS = pd.read_csv(
os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0
).fillna(0)
GOODS["cid"] = GOODS["cid"].astype(str)
GOODS = GOODS.set_index(["cid"])
CORP_TKS = json.load(open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r"))
CORP_TKS = json.load(
open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")
)
GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r"))
CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r"))
def baike(cid, default_v=0):
global GOODS
try:
@ -39,27 +44,41 @@ def baike(cid, default_v=0):
def corpNorm(nm, add_region=True):
global CORP_TKS
if not nm or type(nm)!=type(""):return ""
if not nm or not isinstance(nm, str):
return ""
nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower()
nm = re.sub(r"&amp;", "&", nm)
nm = re.sub(r"[\(\)\+'\"\t \*\\【】-]+", " ", nm)
nm = re.sub(r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE)
nm = re.sub(r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", nm, 10000, re.IGNORECASE)
if not nm or (len(nm)<5 and not regions.isName(nm[0:2])):return nm
nm = re.sub(
r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE
)
nm = re.sub(
r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$",
"",
nm,
10000,
re.IGNORECASE,
)
if not nm or (len(nm) < 5 and not regions.isName(nm[0:2])):
return nm
tks = rag_tokenizer.tokenize(nm).split()
reg = [t for i,t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
reg = [t for i, t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
nm = ""
for t in tks:
if regions.isName(t) or t in CORP_TKS:continue
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):nm += " "
if regions.isName(t) or t in CORP_TKS:
continue
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):
nm += " "
nm += t
r = re.search(r"^([^a-z0-9 \(\)&]{2,})[a-z ]{4,}$", nm.strip())
if r:nm = r.group(1)
if r:
nm = r.group(1)
r = re.search(r"^([a-z ]{3,})[^a-z0-9 \(\)&]{2,}$", nm.strip())
if r:nm = r.group(1)
return nm.strip() + (("" if not reg else "(%s)"%reg[0]) if add_region else "")
if r:
nm = r.group(1)
return nm.strip() + (("" if not reg else "(%s)" % reg[0]) if add_region else "")
def rmNoise(n):
@ -67,33 +86,40 @@ def rmNoise(n):
n = re.sub(r"[,. &()]+", "", n)
return n
GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP])
for c,v in CORP_TAG.items():
for c, v in CORP_TAG.items():
cc = corpNorm(rmNoise(c), False)
if not cc:
logging.debug(c)
CORP_TAG = {corpNorm(rmNoise(c), False):v for c,v in CORP_TAG.items()}
CORP_TAG = {corpNorm(rmNoise(c), False): v for c, v in CORP_TAG.items()}
def is_good(nm):
global GOOD_CORP
if nm.find("外派")>=0:return False
if nm.find("外派") >= 0:
return False
nm = rmNoise(nm)
nm = corpNorm(nm, False)
for n in GOOD_CORP:
if re.match(r"[0-9a-zA-Z]+$", n):
if n == nm: return True
elif nm.find(n)>=0:return True
if n == nm:
return True
elif nm.find(n) >= 0:
return True
return False
def corp_tag(nm):
global CORP_TAG
nm = rmNoise(nm)
nm = corpNorm(nm, False)
for n in CORP_TAG.keys():
if re.match(r"[0-9a-zA-Z., ]+$", n):
if n == nm: return CORP_TAG[n]
elif nm.find(n)>=0:
if len(n)<3 and len(nm)/len(n)>=2:continue
if n == nm:
return CORP_TAG[n]
elif nm.find(n) >= 0:
if len(n) < 3 and len(nm) / len(n) >= 2:
continue
return CORP_TAG[n]
return []


@ -11,27 +11,31 @@
# limitations under the License.
#
TBL = {"94":"EMBA",
"6":"MBA",
"95":"MPA",
"92":"专升本",
"4":"专科",
"90":"中专",
"91":"中技",
"86":"初中",
"3":"博士",
"10":"博士后",
"1":"本科",
"2":"硕士",
"87":"职高",
"89":"高中"
TBL = {
"94": "EMBA",
"6": "MBA",
"95": "MPA",
"92": "专升本",
"4": "专科",
"90": "中专",
"91": "中技",
"86": "初中",
"3": "博士",
"10": "博士后",
"1": "本科",
"2": "硕士",
"87": "职高",
"89": "高中",
}
TBL_ = {v:k for k,v in TBL.items()}
TBL_ = {v: k for k, v in TBL.items()}
def get_name(id):
return TBL.get(str(id), "")
def get_id(nm):
if not nm:return ""
if not nm:
return ""
return TBL_.get(nm.upper().strip(), "")

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@ -16,8 +16,11 @@ import json
import re
import copy
import pandas as pd
current_file_path = os.path.dirname(os.path.abspath(__file__))
TBL = pd.read_csv(os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0).fillna("")
TBL = pd.read_csv(
os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0
).fillna("")
TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip())
GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r"))
GOOD_SCH = set([re.sub(r"[,. &()]+", "", c) for c in GOOD_SCH])
@ -26,14 +29,15 @@ GOOD_SCH = set([re.sub(r"[,. &()]+", "", c) for c in GOOD_SCH])
def loadRank(fnm):
global TBL
TBL["rank"] = 1000000
with open(fnm, "r", encoding='utf-8') as f:
with open(fnm, "r", encoding="utf-8") as f:
while True:
l = f.readline()
if not l:break
l = l.strip("\n").split(",")
line = f.readline()
if not line:
break
line = line.strip("\n").split(",")
try:
nm,rk = l[0].strip(),int(l[1])
#assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
nm, rk = line[0].strip(), int(line[1])
# assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
TBL.loc[((TBL.name_cn == nm) | (TBL.name_en == nm)), "rank"] = rk
except Exception:
pass
@ -44,27 +48,35 @@ loadRank(os.path.join(current_file_path, "res/school.rank.csv"))
def split(txt):
tks = []
for t in re.sub(r"[ \t]+", " ",txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r"[a-zA-Z]", t) and tks:
for t in re.sub(r"[ \t]+", " ", txt).split():
if (
tks
and re.match(r".*[a-zA-Z]$", tks[-1])
and re.match(r"[a-zA-Z]", t)
and tks
):
tks[-1] = tks[-1] + " " + t
else:tks.append(t)
else:
tks.append(t)
return tks
def select(nm):
global TBL
if not nm:return
if isinstance(nm, list):nm = str(nm[0])
if not nm:
return
if isinstance(nm, list):
nm = str(nm[0])
nm = split(nm)[0]
nm = str(nm).lower().strip()
nm = re.sub(r"[(][^()]+[)]", "", nm.lower())
nm = re.sub(r"(^the |[,.&();;·]+|^(英国|美国|瑞士))", "", nm)
nm = re.sub(r"大学.*学院", "大学", nm)
tbl = copy.deepcopy(TBL)
tbl["hit_alias"] = tbl["alias"].map(lambda x:nm in set(x.split("+")))
res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | (tbl.hit_alias == True))]
if res.empty:return
tbl["hit_alias"] = tbl["alias"].map(lambda x: nm in set(x.split("+")))
res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | tbl.hit_alias)]
if res.empty:
return
return json.loads(res.to_json(orient="records"))[0]
@ -74,4 +86,3 @@ def is_good(nm):
nm = re.sub(r"[(][^()]+[)]", "", nm.lower())
nm = re.sub(r"[''`‘’“”,. &();]+", "", nm)
return nm in GOOD_SCH


@ -25,7 +25,8 @@ from xpinyin import Pinyin
from contextlib import contextmanager
class TimeoutException(Exception): pass
class TimeoutException(Exception):
pass
@contextmanager
@ -50,8 +51,10 @@ def rmHtmlTag(line):
def highest_degree(dg):
if not dg: return ""
if type(dg) == type(""): dg = [dg]
if not dg:
return ""
if isinstance(dg, str):
dg = [dg]
m = {"初中": 0, "高中": 1, "中专": 2, "大专": 3, "专升本": 4, "本科": 5, "硕士": 6, "博士": 7, "博士后": 8}
return sorted([(d, m.get(d, -1)) for d in dg], key=lambda x: x[1] * -1)[0][0]
@ -68,10 +71,12 @@ def forEdu(cv):
for ii, n in enumerate(sorted(cv["education_obj"], key=lambda x: x.get("start_time", "3"))):
e = {}
if n.get("end_time"):
if n["end_time"] > edu_end_dt: edu_end_dt = n["end_time"]
if n["end_time"] > edu_end_dt:
edu_end_dt = n["end_time"]
try:
dt = n["end_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt)
ed_dt.append(str(y))
e["end_dt_kwd"] = str(y)
@ -80,7 +85,8 @@ def forEdu(cv):
if n.get("start_time"):
try:
dt = n["start_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt)
st_dt.append(str(y))
e["start_dt_kwd"] = str(y)
@ -89,13 +95,20 @@ def forEdu(cv):
r = schools.select(n.get("school_name", ""))
if r:
if str(r.get("type", "")) == "1": fea.append("211")
if str(r.get("type", "")) == "2": fea.append("211")
if str(r.get("is_abroad", "")) == "1": fea.append("留学")
if str(r.get("is_double_first", "")) == "1": fea.append("双一流")
if str(r.get("is_985", "")) == "1": fea.append("985")
if str(r.get("is_world_known", "")) == "1": fea.append("海外知名")
if r.get("rank") and cv["school_rank_int"] > r["rank"]: cv["school_rank_int"] = r["rank"]
if str(r.get("type", "")) == "1":
fea.append("211")
if str(r.get("type", "")) == "2":
fea.append("211")
if str(r.get("is_abroad", "")) == "1":
fea.append("留学")
if str(r.get("is_double_first", "")) == "1":
fea.append("双一流")
if str(r.get("is_985", "")) == "1":
fea.append("985")
if str(r.get("is_world_known", "")) == "1":
fea.append("海外知名")
if r.get("rank") and cv["school_rank_int"] > r["rank"]:
cv["school_rank_int"] = r["rank"]
if n.get("school_name") and isinstance(n["school_name"], str):
sch.append(re.sub(r"(211|985|重点大学|[,&;-])", "", n["school_name"]))
@ -106,22 +119,25 @@ def forEdu(cv):
maj.append(n["discipline_name"])
e["major_kwd"] = n["discipline_name"]
if not n.get("degree") and "985" in fea and not first_fea: n["degree"] = "1"
if not n.get("degree") and "985" in fea and not first_fea:
n["degree"] = "1"
if n.get("degree"):
d = degrees.get_name(n["degree"])
if d: e["degree_kwd"] = d
if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)",
n.get(
"school_name",
""))): d = "专升本"
if d: deg.append(d)
if d:
e["degree_kwd"] = d
if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name",""))):
d = "专升本"
if d:
deg.append(d)
# for first degree
if not fdeg and d in ["中专", "专升本", "专科", "本科", "大专"]:
fdeg = [d]
if n.get("school_name"): fsch = [n["school_name"]]
if n.get("discipline_name"): fmaj = [n["discipline_name"]]
if n.get("school_name"):
fsch = [n["school_name"]]
if n.get("discipline_name"):
fmaj = [n["discipline_name"]]
first_fea = copy.deepcopy(fea)
edu_nst.append(e)
@ -140,16 +156,26 @@ def forEdu(cv):
else:
cv["sch_rank_kwd"].append("一般学校")
if edu_nst: cv["edu_nst"] = edu_nst
if fea: cv["edu_fea_kwd"] = list(set(fea))
if first_fea: cv["edu_first_fea_kwd"] = list(set(first_fea))
if maj: cv["major_kwd"] = maj
if fsch: cv["first_school_name_kwd"] = fsch
if fdeg: cv["first_degree_kwd"] = fdeg
if fmaj: cv["first_major_kwd"] = fmaj
if st_dt: cv["edu_start_kwd"] = st_dt
if ed_dt: cv["edu_end_kwd"] = ed_dt
if ed_dt: cv["edu_end_int"] = max([int(t) for t in ed_dt])
if edu_nst:
cv["edu_nst"] = edu_nst
if fea:
cv["edu_fea_kwd"] = list(set(fea))
if first_fea:
cv["edu_first_fea_kwd"] = list(set(first_fea))
if maj:
cv["major_kwd"] = maj
if fsch:
cv["first_school_name_kwd"] = fsch
if fdeg:
cv["first_degree_kwd"] = fdeg
if fmaj:
cv["first_major_kwd"] = fmaj
if st_dt:
cv["edu_start_kwd"] = st_dt
if ed_dt:
cv["edu_end_kwd"] = ed_dt
if ed_dt:
cv["edu_end_int"] = max([int(t) for t in ed_dt])
if deg:
if "本科" in deg and "专科" in deg:
deg.append("专升本")
@ -158,8 +184,10 @@ def forEdu(cv):
cv["highest_degree_kwd"] = highest_degree(deg)
if edu_end_dt:
try:
if re.match(r"[0-9]{9,}", edu_end_dt): edu_end_dt = turnTm2Dt(edu_end_dt)
if edu_end_dt.strip("\n") == "至今": edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
if re.match(r"[0-9]{9,}", edu_end_dt):
edu_end_dt = turnTm2Dt(edu_end_dt)
if edu_end_dt.strip("\n") == "至今":
edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
y, m, d = getYMD(edu_end_dt)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e:
@ -171,7 +199,8 @@ def forEdu(cv):
or not cv.get("degree_kwd"):
for c in sch:
if schools.is_good(c):
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好学校")
cv["tag_kwd"].append("好学历")
break
@ -180,28 +209,39 @@ def forEdu(cv):
any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \
or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \
or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]):
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "好学历" not in cv["tag_kwd"]: cv["tag_kwd"].append("好学历")
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
if "好学历" not in cv["tag_kwd"]:
cv["tag_kwd"].append("好学历")
if cv.get("major_kwd"): cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
if cv.get("school_name_kwd"): cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
if cv.get("first_school_name_kwd"): cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
if cv.get("first_major_kwd"): cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))
if cv.get("major_kwd"):
cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
if cv.get("school_name_kwd"):
cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
if cv.get("first_school_name_kwd"):
cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
if cv.get("first_major_kwd"):
cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))
return cv
def forProj(cv):
if not cv.get("project_obj"): return cv
if not cv.get("project_obj"):
return cv
pro_nms, desc = [], []
for i, n in enumerate(
sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if type(x) == type({}) else "",
sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "",
reverse=True)):
if n.get("name"): pro_nms.append(n["name"])
if n.get("describe"): desc.append(str(n["describe"]))
if n.get("responsibilities"): desc.append(str(n["responsibilities"]))
if n.get("achivement"): desc.append(str(n["achivement"]))
if n.get("name"):
pro_nms.append(n["name"])
if n.get("describe"):
desc.append(str(n["describe"]))
if n.get("responsibilities"):
desc.append(str(n["responsibilities"]))
if n.get("achivement"):
desc.append(str(n["achivement"]))
if pro_nms:
# cv["pro_nms_tks"] = rag_tokenizer.tokenize(" ".join(pro_nms))
@ -233,15 +273,16 @@ def forWork(cv):
work_st_tm = ""
corp_tags = []
for i, n in enumerate(
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if type(x) == type({}) else "",
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "",
reverse=True)):
if type(n) == type(""):
if isinstance(n, str):
try:
n = json_loads(n)
except Exception:
continue
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"]
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm):
work_st_tm = n["start_time"]
for c in flds:
if not n.get(c) or str(n[c]) == '0':
fea[c].append("")
@ -262,14 +303,18 @@ def forWork(cv):
fea[c].append(rmHtmlTag(str(n[c]).lower()))
y, m, d = getYMD(n.get("start_time"))
if not y or not m: continue
if not y or not m:
continue
st = "%s-%02d-%02d" % (y, int(m), int(d))
latest_job_tm = st
y, m, d = getYMD(n.get("end_time"))
if (not y or not m) and i > 0: continue
if not y or not m or int(y) > 2022: y, m, d = getYMD(str(n.get("updated_at", "")))
if not y or not m: continue
if (not y or not m) and i > 0:
continue
if not y or not m or int(y) > 2022:
y, m, d = getYMD(str(n.get("updated_at", "")))
if not y or not m:
continue
ed = "%s-%02d-%02d" % (y, int(m), int(d))
try:
@ -279,22 +324,28 @@ def forWork(cv):
if n.get("scale"):
r = re.search(r"^([0-9]+)", str(n["scale"]))
if r: scales.append(int(r.group(1)))
if r:
scales.append(int(r.group(1)))
if goodcorp:
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好公司")
if goodcorp_:
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好公司(曾)")
if corp_tags:
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].extend(corp_tags)
cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)]
if latest_job_tm: cv["latest_job_dt"] = latest_job_tm
if fea["corporation_id"]: cv["corporation_id"] = fea["corporation_id"]
if latest_job_tm:
cv["latest_job_dt"] = latest_job_tm
if fea["corporation_id"]:
cv["corporation_id"] = fea["corporation_id"]
if fea["position_name"]:
cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0])
@ -317,18 +368,23 @@ def forWork(cv):
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0])
cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:]))
if fea["subordinates_count"]: fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
if fea["subordinates_count"]:
fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
re.match(r"[^0-9]+$", str(i))]
if fea["subordinates_count"]: cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
if fea["subordinates_count"]:
cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
if type(cv.get("corporation_id")) == type(1): cv["corporation_id"] = [str(cv["corporation_id"])]
if not cv.get("corporation_id"): cv["corporation_id"] = []
if isinstance(cv.get("corporation_id"), int):
cv["corporation_id"] = [str(cv["corporation_id"])]
if not cv.get("corporation_id"):
cv["corporation_id"] = []
for i in cv.get("corporation_id", []):
cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0)
if work_st_tm:
try:
if re.match(r"[0-9]{9,}", work_st_tm): work_st_tm = turnTm2Dt(work_st_tm)
if re.match(r"[0-9]{9,}", work_st_tm):
work_st_tm = turnTm2Dt(work_st_tm)
y, m, d = getYMD(work_st_tm)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e:
@ -339,28 +395,37 @@ def forWork(cv):
cv["dua_flt"] = np.mean(duas)
cv["cur_dua_int"] = duas[0]
cv["job_num_int"] = len(duas)
if scales: cv["scale_flt"] = np.max(scales)
if scales:
cv["scale_flt"] = np.max(scales)
return cv
def turnTm2Dt(b):
if not b: return
if not b:
return
b = str(b).strip()
if re.match(r"[0-9]{10,}", b): b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
if re.match(r"[0-9]{10,}", b):
b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
return b
def getYMD(b):
y, m, d = "", "", "01"
if not b: return (y, m, d)
if not b:
return (y, m, d)
b = turnTm2Dt(b)
if re.match(r"[0-9]{4}", b): y = int(b[:4])
if re.match(r"[0-9]{4}", b):
y = int(b[:4])
r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b)
if r: m = r.group(1)
if r:
m = r.group(1)
r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b)
if r: d = r.group(1)
if not d or int(d) == 0 or int(d) > 31: d = "1"
if not m or int(m) > 12 or int(m) < 1: m = "1"
if r:
d = r.group(1)
if not d or int(d) == 0 or int(d) > 31:
d = "1"
if not m or int(m) > 12 or int(m) < 1:
m = "1"
return (y, m, d)
@ -369,7 +434,8 @@ def birth(cv):
cv["integerity_flt"] *= 0.9
return cv
y, m, d = getYMD(cv["birth"])
if not m or not y: return cv
if not m or not y:
return cv
b = "%s-%02d-%02d" % (y, int(m), int(d))
cv["birth_dt"] = b
cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d))
@ -380,7 +446,8 @@ def birth(cv):
def parse(cv):
for k in cv.keys():
if cv[k] == '\\N': cv[k] = ''
if cv[k] == '\\N':
cv[k] = ''
# cv = cv.asDict()
tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names",
"expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name",
@ -402,9 +469,12 @@ def parse(cv):
rmkeys = []
for k in cv.keys():
if cv[k] is None: rmkeys.append(k)
if (type(cv[k]) == type([]) or type(cv[k]) == type("")) and len(cv[k]) == 0: rmkeys.append(k)
for k in rmkeys: del cv[k]
if cv[k] is None:
rmkeys.append(k)
if (isinstance(cv[k], list) or isinstance(cv[k], str)) and len(cv[k]) == 0:
rmkeys.append(k)
for k in rmkeys:
del cv[k]
integerity = 0.
flds_num = 0.
@ -414,7 +484,8 @@ def parse(cv):
flds_num += len(flds)
for f in flds:
v = str(cv.get(f, ""))
if len(v) > 0 and v != '0' and v != '[]': integerity += 1
if len(v) > 0 and v != '0' and v != '[]':
integerity += 1
hasValues(tks_fld)
hasValues(small_tks_fld)
@ -433,7 +504,8 @@ def parse(cv):
(r"[ \(\)人/·0-9-]+", ""),
(r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]:
cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE)
if len(cv["corporation_type"]) < 2: del cv["corporation_type"]
if len(cv["corporation_type"]) < 2:
del cv["corporation_type"]
if cv.get("political_status"):
for p, r in [
@ -441,9 +513,11 @@ def parse(cv):
(r".*(无党派|公民).*", "群众"),
(r".*团员.*", "团员")]:
cv["political_status"] = re.sub(p, r, cv["political_status"])
if not re.search(r"[党团群]", cv["political_status"]): del cv["political_status"]
if not re.search(r"[党团群]", cv["political_status"]):
del cv["political_status"]
if cv.get("phone"): cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
if cv.get("phone"):
cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
keys = list(cv.keys())
for k in keys:
@ -454,9 +528,11 @@ def parse(cv):
cv[k] = [a for _, a in cv[k].items()]
nms = []
for n in cv[k]:
if type(n) != type({}) or "name" not in n or not n.get("name"): continue
if not isinstance(n, dict) or "name" not in n or not n.get("name"):
continue
n["name"] = re.sub(r"(442|\t )", "", n["name"]).strip().lower()
if not n["name"]: continue
if not n["name"]:
continue
nms.append(n["name"])
if nms:
t = k[:-4]
@ -469,15 +545,18 @@ def parse(cv):
# tokenize fields
if k in tks_fld:
cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k])
if k in small_tks_fld: cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
if k in small_tks_fld:
cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
# keyword fields
if k in kwd_fld: cv[f"{k}_kwd"] = [n.lower()
if k in kwd_fld:
cv[f"{k}_kwd"] = [n.lower()
for n in re.split(r"[\t,;. ]",
re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1\2", cv[k])
) if n]
if k in num_fld and cv.get(k): cv[f"{k}_int"] = cv[k]
if k in num_fld and cv.get(k):
cv[f"{k}_int"] = cv[k]
cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "")
# for name field
@ -501,10 +580,12 @@ def parse(cv):
cv["name_py_pref0_tks"] = ""
cv["name_py_pref_tks"] = ""
for py in PY.get_pinyins(nm[:20], ''):
for i in range(2, len(py) + 1): cv["name_py_pref_tks"] += " " + py[:i]
for i in range(2, len(py) + 1):
cv["name_py_pref_tks"] += " " + py[:i]
for py in PY.get_pinyins(nm[:20], ' '):
py = py.split()
for i in range(1, len(py) + 1): cv["name_py_pref0_tks"] += " " + "".join(py[:i])
for i in range(1, len(py) + 1):
cv["name_py_pref0_tks"] += " " + "".join(py[:i])
cv["name_kwd"] = name
cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3]
@ -526,22 +607,30 @@ def parse(cv):
cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S')
else:
y, m, d = getYMD(str(cv.get("updated_at", "")))
if not y: y = "2012"
if not m: m = "01"
if not d: d = "01"
if not y:
y = "2012"
if not m:
m = "01"
if not d:
d = "01"
cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
# long text tokenize
if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
if cv.get("responsibilities"):
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
# for yes or no field
fea = []
for f, y, n in is_fld:
if f not in cv: continue
if cv[f] == '': fea.append(y)
if cv[f] == '': fea.append(n)
if f not in cv:
continue
if cv[f] == '':
fea.append(y)
if cv[f] == '':
fea.append(n)
if fea: cv["tag_kwd"] = fea
if fea:
cv["tag_kwd"] = fea
cv = forEdu(cv)
cv = forProj(cv)
@ -550,9 +639,11 @@ def parse(cv):
cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])]
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
for j in cv.get("sch_rank_kwd", []): cv["corp_proj_sch_deg_kwd"][i] += "+" + j
for j in cv.get("sch_rank_kwd", []):
cv["corp_proj_sch_deg_kwd"][i] += "+" + j
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
if cv.get("highest_degree_kwd"): cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
if cv.get("highest_degree_kwd"):
cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
try:
if not cv.get("work_exp_flt") and cv.get("work_start_time"):
@ -565,17 +656,21 @@ def parse(cv):
cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y)
except Exception as e:
logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time")))
if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
if "work_exp_flt" not in cv and cv.get("work_experience", 0):
cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
keys = list(cv.keys())
for k in keys:
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k): del cv[k]
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k):
del cv[k]
for k in cv.keys():
if not re.search("_(kwd|id)$", k) or type(cv[k]) != type([]): continue
if not re.search("_(kwd|id)$", k) or not isinstance(cv[k], list):
continue
cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']]))
keys = [k for k in cv.keys() if re.search(r"_feas*$", k)]
for k in keys:
if cv[k] <= 0: del cv[k]
if cv[k] <= 0:
del cv[k]
cv["tob_resume_id"] = str(cv["tob_resume_id"])
cv["id"] = cv["tob_resume_id"]
@ -592,5 +687,6 @@ def dealWithInt64(d):
if isinstance(d, list):
d = [dealWithInt64(t) for t in d]
if isinstance(d, np.integer): d = int(d)
if isinstance(d, np.integer):
d = int(d)
return d
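
Many of the resume-parser hunks above swap `type(x) == type({})` style comparisons for `isinstance` checks (Ruff rule E721). A minimal standalone sketch of the pattern, using made-up records rather than parser data:

```python
# Hypothetical records illustrating the isinstance pattern applied above.
records = [{"name": "demo project"}, "not-a-dict", 42]

for rec in records:
    # Old style, flagged by Ruff E721: type(rec) == type({})
    # New style: isinstance also accepts dict subclasses and reads clearly.
    if isinstance(rec, dict) and rec.get("name"):
        print(rec["name"])
    elif isinstance(rec, str):
        print("skipping string:", rec)
```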

View File

@ -51,6 +51,7 @@ class RAGFlowTxtParser:
dels = [d for d in dels if d]
dels = "|".join(dels)
secs = re.split(r"(%s)" % dels, txt)
for sec in secs: add_chunk(sec)
for sec in secs:
add_chunk(sec)
return [[c, ""] for c in cks]

View File

@ -18,7 +18,6 @@ from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
def init_in_out(args):
from PIL import Image
import os
@ -47,7 +46,7 @@ def init_in_out(args):
try:
images.append(Image.open(fnm))
outputs.append(os.path.split(fnm)[-1])
except Exception as e:
except Exception:
traceback.print_exc()
if os.path.isdir(args.inputs):
@ -56,6 +55,16 @@ def init_in_out(args):
else:
images_and_outputs(args.inputs)
for i in range(len(outputs)): outputs[i] = os.path.join(args.output_dir, outputs[i])
for i in range(len(outputs)):
outputs[i] = os.path.join(args.output_dir, outputs[i])
return images, outputs
return images, outputs
__all__ = [
"OCR",
"Recognizer",
"LayoutRecognizer",
"TableStructureRecognizer",
"init_in_out",
]
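
The hunk above replaces implicit wildcard exports with explicit imports plus an `__all__` list, which satisfies Ruff's F401/F403 checks. A tiny runnable sketch of the same re-export style, using stdlib names purely as placeholders for the package's own modules:

```python
# Placeholder re-exports: each public name is imported explicitly and listed
# in __all__, so star-imports and linters see exactly what the module exposes.
from json import dumps
from math import sqrt

__all__ = [
    "dumps",
    "sqrt",
]

if __name__ == "__main__":
    print(dumps({"root_of_two": sqrt(2)}))
```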

View File

@ -42,7 +42,7 @@ class LayoutRecognizer(Recognizer):
get_project_base_directory(),
"rag/res/deepdoc")
super().__init__(self.labels, domain, model_dir)
except Exception as e:
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
@ -77,7 +77,7 @@ class LayoutRecognizer(Recognizer):
"page_number": pn,
} for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
lts = self.sort_Y_firstly(lts, np.mean(
[l["bottom"] - l["top"] for l in lts]) / 2)
[lt["bottom"] - lt["top"] for lt in lts]) / 2)
lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts)
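
Two recurring cleanups appear in this hunk: the ambiguous loop variable `l` becomes `lt` (Ruff E741), and `except Exception as e` drops the unused binding (F841). A self-contained sketch of both, with invented box values:

```python
import traceback

# Invented layout boxes standing in for the recognizer output above.
boxes = [{"top": 10.0, "bottom": 42.0}, {"top": 50.0, "bottom": 81.0}]

# E741: a readable loop name instead of the easily misread `l`.
half_mean_height = sum(lt["bottom"] - lt["top"] for lt in boxes) / len(boxes) / 2

try:
    raise ValueError("demo failure")
except Exception:
    # F841 cleanup: no `as e` when the exception object is never used;
    # traceback.print_exc() still reports the active exception.
    traceback.print_exc()

print("half mean height:", half_mean_height)
```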

View File

@ -19,7 +19,9 @@ from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory
from .operators import *
import math
import numpy as np
import cv2
import onnxruntime as ort
from .postprocess import build_post_process
@ -484,7 +486,7 @@ class OCR(object):
"rag/res/deepdoc")
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)
except Exception as e:
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)

View File

@ -232,7 +232,7 @@ class LinearResize(object):
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
_im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
@ -255,7 +255,7 @@ class LinearResize(object):
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
_im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
@ -581,7 +581,7 @@ class SRResize(object):
return data
images_HR = data["image_hr"]
label_strs = data["label"]
_label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR

View File

@ -121,7 +121,7 @@ class DBPostProcess(object):
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
_img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]

View File

@ -13,15 +13,18 @@
import logging
import os
import math
import numpy as np
import cv2
from copy import deepcopy
import onnxruntime as ort
from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory
from .operators import *
class Recognizer(object):
def __init__(self, label_list, task_name, model_dir=None):
"""
@ -277,7 +280,8 @@ class Recognizer(object):
return
min_dis, min_i = 1000000, None
for i,b in enumerate(boxes):
if box.get("layoutno", "0") != b.get("layoutno", "0"): continue
if box.get("layoutno", "0") != b.get("layoutno", "0"):
continue
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
if dis < min_dis:
min_i = i
@ -402,7 +406,8 @@ class Recognizer(object):
scores = np.max(boxes[:, 4:], axis=1)
boxes = boxes[scores > thr, :]
scores = scores[scores > thr]
if len(boxes) == 0: return []
if len(boxes) == 0:
return []
# Get the class with the highest confidence
class_ids = np.argmax(boxes[:, 4:], axis=1)
@ -432,7 +437,8 @@ class Recognizer(object):
for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray):
imgs.append(np.array(image_list[i]))
else: imgs.append(image_list[i])
else:
imgs.append(image_list[i])
batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
for i in range(batch_loop_cnt):

View File

@ -88,7 +88,8 @@ class CommunityReportsExtractor:
("findings", list),
("rating", float),
("rating_explanation", str),
]): continue
]):
continue
response["weight"] = weight
response["entities"] = ents
except Exception as e:
@ -100,7 +101,8 @@ class CommunityReportsExtractor:
res_str.append(self._get_text_output(response))
res_dict.append(response)
over += 1
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
if callback:
callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
return CommunityReportsResult(
structured_output=res_dict,

View File

@ -8,6 +8,7 @@ Reference:
from typing import Any
import numpy as np
import networkx as nx
from dataclasses import dataclass
from graphrag.leiden import stable_largest_connected_component

View File

@ -129,9 +129,11 @@ class GraphExtractor:
source_doc_map[doc_index] = text
all_records[doc_index] = result
total_token_count += token_count
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
if callback:
callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e:
if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e)))
if callback:
callback(msg="Knowledge graph extraction error:{}".format(str(e)))
logging.exception("error extracting graph")
self._on_error(
e,
@ -164,7 +166,8 @@ class GraphExtractor:
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.3}
response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
if response.find("**ERROR**") >= 0: raise Exception(response)
if response.find("**ERROR**") >= 0:
raise Exception(response)
token_count = num_tokens_from_string(text + response)
results = response or ""
@ -175,7 +178,8 @@ class GraphExtractor:
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._llm.chat("", history, gen_conf)
if response.find("**ERROR**") >=0: raise Exception(response)
if response.find("**ERROR**") >=0:
raise Exception(response)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag

View File

@ -134,7 +134,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
callback(0.75, "Extracting mind graph.")
mindmap = MindMapExtractor(llm_bdl)
mg = mindmap(_chunks).output
if not len(mg.keys()): return chunks
if not len(mg.keys()):
return chunks
logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
chunks.append(

View File

@ -78,7 +78,8 @@ def _compute_leiden_communities(
) -> dict[int, dict[str, int]]:
"""Return Leiden root communities."""
results: dict[int, dict[str, int]] = {}
if is_empty(graph): return results
if is_empty(graph):
return results
if use_lcc:
graph = stable_largest_connected_component(graph)
@ -100,7 +101,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
logging.debug(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
if not graph.nodes(): return {}
if not graph.nodes():
return {}
node_id_to_community_map = _compute_leiden_communities(
graph=graph,
@ -125,9 +127,11 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
result[community_id]["nodes"].append(node_id)
result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
weights = [comm["weight"] for _, comm in result.items()]
if not weights:continue
if not weights:
continue
max_weight = max(weights)
for _, comm in result.items(): comm["weight"] /= max_weight
for _, comm in result.items():
comm["weight"] /= max_weight
return results_by_level
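
The Leiden hunk above guards against an empty weight list before normalizing each community's weight by the level maximum. The same logic as a compact standalone sketch with hypothetical communities:

```python
# Hypothetical community weights for one hierarchy level.
result = {
    "c0": {"nodes": ["a", "b"], "weight": 12.0},
    "c1": {"nodes": ["c"], "weight": 3.0},
}

weights = [comm["weight"] for comm in result.values()]
if weights:  # skip empty levels instead of dividing by zero
    max_weight = max(weights)
    for comm in result.values():
        comm["weight"] /= max_weight

print({cid: comm["weight"] for cid, comm in result.items()})  # {'c0': 1.0, 'c1': 0.25}
```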

View File

@ -1 +1,5 @@
from .ragflow_chat import *
from .ragflow_chat import RAGFlowChat
__all__ = [
"RAGFlowChat"
]

View File

@ -2,7 +2,6 @@ import logging
import requests
from bridge.context import ContextType # Import Context, ContextType
from bridge.reply import Reply, ReplyType # Import Reply, ReplyType
from bridge import *
from plugins import Plugin, register # Import Plugin and register
from plugins.event import Event, EventContext, EventAction # Import event-related classes

View File

@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
@ -102,7 +102,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
@ -112,7 +112,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")

View File

@ -75,7 +75,7 @@ def chunk(
_add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(l, "") for l in HtmlParser.parser_txt("\n".join(html_txt)) if l
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line
]
st = timer()

View File

@ -18,7 +18,8 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
)
for c in chunks: c["docnm_kwd"] = filename
for c in chunks:
c["docnm_kwd"] = filename
doc = {
"docnm_kwd": filename,

View File

@ -48,7 +48,7 @@ class Docx(DocxParser):
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
return [l for l in lines if l]
return [line for line in lines if line]
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document(
@ -60,7 +60,8 @@ class Docx(DocxParser):
if pn > to_page:
break
question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"):continue
if not p_text.strip("\n"):
continue
lines.append((question_level, p_text))
for run in p.runs:
@ -78,19 +79,21 @@ class Docx(DocxParser):
if lines[e][0] <= lines[s][0]:
break
e += 1
if e - s == 1 and visit[s]: continue
if e - s == 1 and visit[s]:
continue
sec = []
next_level = lines[s][0] + 1
while not sec and next_level < 22:
for i in range(s+1, e):
if lines[i][0] != next_level: continue
if lines[i][0] != next_level:
continue
sec.append(lines[i][1])
visit[i] = True
next_level += 1
sec.insert(0, lines[s][1])
sections.append("\n".join(sec))
return [l for l in sections if l]
return [s for s in sections if s]
def __str__(self) -> str:
return f'''
@ -168,13 +171,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.doc$", filename, re.IGNORECASE):
@ -182,7 +185,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:

View File

@ -190,7 +190,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0]) < 3:
sections = [(t, l, [[0] * 5]) for t, l in sections]
sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
# set pivot using the most frequent type of title,
# then merge between 2 pivot
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1:
@ -211,7 +211,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
else:
bull = bullets_category([txt for txt, _, _ in sections])
most_level, levels = title_frequency(
bull, [(txt, l) for txt, l, poss in sections])
bull, [(txt, lvl) for txt, lvl, _ in sections])
assert len(sections) == len(levels)
sec_ids = []
@ -225,7 +225,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections = [(txt, sec_ids[i], poss)
for i, (txt, _, poss) in enumerate(sections)]
for (img, rows), poss in tbls:
if not rows: continue
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))

View File

@ -54,7 +54,8 @@ class Pdf(PdfParser):
sections = [(b["text"], self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)]
for (img, rows), poss in tbls:
if not rows:continue
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
@ -109,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:

View File

@ -171,7 +171,7 @@ class Pdf(PdfParser):
tbl_bottom = tbls[tbl_index][1][0][4]
tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
.format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
tbl_text = ''.join(tbls[tbl_index][0][1])
_tbl_text = ''.join(tbls[tbl_index][0][1])
return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag,
@ -325,9 +325,11 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
txt = get_text(filename, binary)
lines = txt.split("\n")
comma, tab = 0, 0
for l in lines:
if len(l.split(",")) == 2: comma += 1
if len(l.split("\t")) == 2: tab += 1
for line in lines:
if len(line.split(",")) == 2:
comma += 1
if len(line.split("\t")) == 2:
tab += 1
delimiter = "\t" if tab >= comma else ","
fails = []
@ -336,18 +338,21 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
while i < len(lines):
arr = lines[i].split(delimiter)
if len(arr) != 2:
if question: answer += "\n" + lines[i]
if question:
answer += "\n" + lines[i]
else:
fails.append(str(i+1))
elif len(arr) == 2:
if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng))
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
question, answer = arr
i += 1
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
if question: res.append(beAdoc(deepcopy(doc), question, answer, eng))
if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@ -367,19 +372,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
last_question, last_answer = "", ""
_last_question, last_answer = "", ""
question_stack, level_stack = [], []
code_block = False
level_index = [-1] * 7
for index, l in enumerate(lines):
if l.strip().startswith('```'):
for index, line in enumerate(lines):
if line.strip().startswith('```'):
code_block = not code_block
question_level, question = 0, ''
if not code_block:
question_level, question = mdQuestionLevel(l)
question_level, question = mdQuestionLevel(line)
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{l}'
last_answer = f'{last_answer}\n{line}'
else: # is a question
if last_answer.strip():
sum_question = '\n'.join(question_stack)

View File

@ -41,14 +41,16 @@ class Excel(ExcelParser):
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
if not rows:continue
if not rows:
continue
headers = [cell.value for cell in rows[0]]
missed = set([i for i, h in enumerate(headers) if h is None])
headers = [
cell.value for i,
cell in enumerate(
rows[0]) if i not in missed]
if not headers:continue
if not headers:
continue
data = []
for i, r in enumerate(rows[1:]):
rn += 1
@ -88,7 +90,6 @@ def trans_bool(s):
def column_data_type(arr):
arr = list(arr)
uni = len(set([a for a in arr if a is not None]))
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
@ -157,7 +158,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
continue
if i >= to_page:
break
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
row = [field for field in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers):
fails.append(str(i))
continue
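
The Excel hunk above now skips empty sheets and drops header cells that are `None` before collecting data rows. A small sketch of that header filtering with made-up rows, so no openpyxl workbook is needed:

```python
# Made-up rows standing in for openpyxl cell values.
rows = [
    ["name", None, "score"],        # header row with an empty column
    ["alice", "ignored", 91],
    ["bob", "ignored", 84],
]

if rows:
    missed = {i for i, h in enumerate(rows[0]) if h is None}
    headers = [h for i, h in enumerate(rows[0]) if i not in missed]
    data = [[v for i, v in enumerate(r) if i not in missed] for r in rows[1:]]
    print(headers)  # ['name', 'score']
    print(data)     # [['alice', 91], ['bob', 84]]
```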

View File

@ -13,12 +13,124 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .embedding_model import *
from .chat_model import *
from .cv_model import *
from .rerank_model import *
from .sequence2txt_model import *
from .tts_model import *
from .embedding_model import (
OllamaEmbed,
LocalAIEmbed,
OpenAIEmbed,
AzureEmbed,
XinferenceEmbed,
QWenEmbed,
ZhipuEmbed,
FastEmbed,
YoudaoEmbed,
BaiChuanEmbed,
JinaEmbed,
DefaultEmbedding,
MistralEmbed,
BedrockEmbed,
GeminiEmbed,
NvidiaEmbed,
LmStudioEmbed,
OpenAI_APIEmbed,
CoHereEmbed,
TogetherAIEmbed,
PerfXCloudEmbed,
UpstageEmbed,
SILICONFLOWEmbed,
ReplicateEmbed,
BaiduYiyanEmbed,
VoyageEmbed,
HuggingFaceEmbed,
VolcEngineEmbed,
)
from .chat_model import (
GptTurbo,
AzureChat,
ZhipuChat,
QWenChat,
OllamaChat,
LocalAIChat,
XinferenceChat,
MoonshotChat,
DeepSeekChat,
VolcEngineChat,
BaiChuanChat,
MiniMaxChat,
MistralChat,
GeminiChat,
BedrockChat,
GroqChat,
OpenRouterChat,
StepFunChat,
NvidiaChat,
LmStudioChat,
OpenAI_APIChat,
CoHereChat,
LeptonAIChat,
TogetherAIChat,
PerfXCloudChat,
UpstageChat,
NovitaAIChat,
SILICONFLOWChat,
YiChat,
ReplicateChat,
HunyuanChat,
SparkChat,
BaiduYiyanChat,
AnthropicChat,
GoogleChat,
HuggingFaceChat,
)
from .cv_model import (
GptV4,
AzureGptV4,
OllamaCV,
XinferenceCV,
QWenCV,
Zhipu4V,
LocalCV,
GeminiCV,
OpenRouterCV,
LocalAICV,
NvidiaCV,
LmStudioCV,
StepFunCV,
OpenAI_APICV,
TogetherAICV,
YiCV,
HunyuanCV,
)
from .rerank_model import (
LocalAIRerank,
DefaultRerank,
JinaRerank,
YoudaoRerank,
XInferenceRerank,
NvidiaRerank,
LmStudioRerank,
OpenAI_APIRerank,
CoHereRerank,
TogetherAIRerank,
SILICONFLOWRerank,
BaiduYiyanRerank,
VoyageRerank,
QWenRerank,
)
from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
AzureSeq2txt,
XinferenceSeq2txt,
TencentCloudSeq2txt,
)
from .tts_model import (
FishAudioTTS,
QwenTTS,
OpenAITTS,
SparkTTS,
XinferenceTTS,
)
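
The explicit imports above feed the provider-name registries that follow (`EmbeddingModel`, `ChatModel`, and so on). A hedged sketch of how such a name-to-class registry is typically consulted; the stub classes and helper below are illustrative only, not code from this commit:

```python
# Illustrative registry lookup with stand-in classes.
class OpenAIEmbedStub:
    def __init__(self, key, model_name):
        self.key, self.model_name = key, model_name

class OllamaEmbedStub(OpenAIEmbedStub):
    pass

EMBEDDING_REGISTRY = {
    "OpenAI": OpenAIEmbedStub,
    "Ollama": OllamaEmbedStub,
}

def make_embedder(factory, key, model_name):
    try:
        cls = EMBEDDING_REGISTRY[factory]
    except KeyError:
        raise ValueError(f"Unsupported embedding factory: {factory}") from None
    return cls(key, model_name)

print(type(make_embedder("Ollama", "demo-key", "demo-model")).__name__)
```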
EmbeddingModel = {
"Ollama": OllamaEmbed,
@ -48,7 +160,7 @@ EmbeddingModel = {
"BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed,
"HuggingFace": HuggingFaceEmbed,
"VolcEngine":VolcEngineEmbed,
"VolcEngine": VolcEngineEmbed,
}
CvModel = {
@ -68,7 +180,7 @@ CvModel = {
"OpenAI-API-Compatible": OpenAI_APICV,
"TogetherAI": TogetherAICV,
"01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV
"Tencent Hunyuan": HunyuanCV,
}
ChatModel = {
@ -111,7 +223,7 @@ ChatModel = {
}
RerankModel = {
"LocalAI":LocalAIRerank,
"LocalAI": LocalAIRerank,
"BAAI": DefaultRerank,
"Jina": JinaRerank,
"Youdao": YoudaoRerank,
@ -132,7 +244,7 @@ Seq2txtModel = {
"Tongyi-Qianwen": QWenSeq2txt,
"Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt
"Tencent Cloud": TencentCloudSeq2txt,
}
TTSModel = {

View File

@ -69,7 +69,8 @@ class Base(ABC):
stream=True,
**gen_conf)
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@ -81,7 +82,8 @@ class Base(ABC):
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens
else:
total_tokens = resp.usage.total_tokens
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -98,13 +100,15 @@ class Base(ABC):
class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url)
class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url = "https://api.moonshot.cn/v1"
if not base_url:
base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url)
@ -128,7 +132,8 @@ class HuggingFaceChat(Base):
class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url = "https://api.deepseek.com/v1"
if not base_url:
base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)
@ -202,7 +207,8 @@ class BaiChuanChat(Base):
stream=True,
**self._format_params(gen_conf))
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@ -313,8 +319,10 @@ class ZhipuChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
try:
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
@ -333,8 +341,10 @@ class ZhipuChat(Base):
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
tk_count = 0
try:
@ -345,7 +355,8 @@ class ZhipuChat(Base):
**gen_conf
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
@ -354,7 +365,8 @@ class ZhipuChat(Base):
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -372,11 +384,16 @@ class OllamaChat(Base):
history.insert(0, {"role": "system", "content": system})
try:
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@ -392,11 +409,16 @@ class OllamaChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@ -636,7 +658,8 @@ class MistralChat(Base):
messages=history,
**gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content: continue
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
@ -1196,7 +1219,8 @@ class SparkChat(Base):
assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
if model_name in model2version:
model_version = model2version[model_name]
else: model_version = model_name
else:
model_version = model_name
super().__init__(key, model_version, base_url)
@ -1281,8 +1305,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
try:
@ -1312,8 +1338,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
total_tokens = 0
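
Several chat-model hunks above expand one-line `if` statements that copy optional generation settings into a backend-specific `options` dict (for example mapping `max_tokens` to Ollama's `num_predict`). A compact sketch of that copy-only-if-present logic; the rename table itself is illustrative:

```python
# Copy a generation setting only when the caller actually provided it.
gen_conf = {"temperature": 0.3, "max_tokens": 512}

renames = {
    "temperature": "temperature",
    "max_tokens": "num_predict",
    "top_p": "top_p",
    "presence_penalty": "presence_penalty",
    "frequency_penalty": "frequency_penalty",
}

options = {dst: gen_conf[src] for src, dst in renames.items() if src in gen_conf}
print(options)  # {'temperature': 0.3, 'num_predict': 512}
```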

View File

@ -25,6 +25,7 @@ import base64
from io import BytesIO
import json
import requests
from transformers import GenerationConfig
from rag.nlp import is_english
from api.utils import get_uuid
@ -77,14 +78,16 @@ class Base(ABC):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -99,7 +102,7 @@ class Base(ABC):
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
except Exception:
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
@ -139,7 +142,8 @@ class Base(ABC):
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url:
base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@ -149,7 +153,8 @@ class GptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"
res = self.client.chat.completions.create(
model=self.model_name,
@ -171,7 +176,8 @@ class AzureGptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"
res = self.client.chat.completions.create(
model=self.model_name,
@ -344,14 +350,16 @@ class Zhipu4V(Base):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -389,11 +397,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@ -414,11 +427,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@ -469,7 +487,7 @@ class XinferenceCV(Base):
class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client, GenerativeModel, GenerationConfig
from google.generativeai import client, GenerativeModel
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
@ -503,7 +521,7 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)
response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
@ -519,7 +537,6 @@ class GeminiCV(Base):
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
ans = ""
tk_count = 0
try:
for his in history:
if his["role"] == "assistant":
@ -529,14 +546,15 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)
response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)), stream=True)
for resp in response:
if not resp.text: continue
if not resp.text:
continue
ans += resp.text
yield ans
except Exception as e:
@ -632,7 +650,8 @@ class NvidiaCV(Base):
class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url: base_url="https://api.stepfun.com/v1"
if not base_url:
base_url="https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
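
The expanded `if not base_url:` guards above all follow one pattern: fall back to a provider default when the caller passes an empty or missing URL. A minimal sketch of the guard, with a placeholder class and default URL:

```python
class DemoClient:
    # Placeholder default, not a real endpoint.
    DEFAULT_BASE_URL = "https://api.example.com/v1"

    def __init__(self, key, model_name, base_url=None):
        if not base_url:
            base_url = self.DEFAULT_BASE_URL
        self.key = key
        self.model_name = model_name
        self.base_url = base_url

print(DemoClient("sk-demo", "demo-model").base_url)  # falls back to the default
```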

View File

@ -15,12 +15,9 @@
#
import requests
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from openai import OpenAI
import os
import json
from rag.utils import num_tokens_from_string
import base64
@ -49,7 +46,8 @@ class Base(ABC):
class GPTSeq2txt(Base):
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name

View File

@ -16,7 +16,6 @@
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
@ -175,7 +174,8 @@ class QwenTTS(Base):
class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url

View File

@ -222,7 +222,8 @@ def bullets_category(sections):
def is_english(texts):
eng = 0
if not texts: return False
if not texts:
return False
for t in texts:
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
eng += 1
@ -250,7 +251,8 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = []
# wrap up as es documents
for ck in chunks:
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
if pdf_parser:
@ -269,7 +271,8 @@ def tokenize_chunks_docx(chunks, doc, eng, images):
res = []
# wrap up as es documents
for ck, image in zip(chunks, images):
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
@ -288,8 +291,10 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d = copy.deepcopy(doc)
tokenize(d, rows, eng)
d["content_with_weight"] = rows
if img: d["image"] = img
if poss: add_positions(d, poss)
if img:
d["image"] = img
if poss:
add_positions(d, poss)
res.append(d)
continue
de = "; " if eng else " "
@ -387,9 +392,9 @@ def title_frequency(bull, sections):
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size
most_level = bullets_size+1
for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if l <= bullets_size:
most_level = l
for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if level <= bullets_size:
most_level = level
break
return most_level, levels
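
The renamed loop above picks the most frequent title level that is still within the bullet range. The same selection as a standalone sketch with invented level counts:

```python
from collections import Counter

bullets_size = 3
levels = [1, 1, 2, 2, 2, 4, 4, 5]  # hypothetical per-section title levels

most_level = bullets_size + 1
for level, _count in sorted(Counter(levels).items(), key=lambda x: -x[1]):
    if level <= bullets_size:
        most_level = level
        break

print(most_level)  # prints 2, the most frequent level within the bullet range
```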
@ -504,7 +509,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。"):
def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos: pos = ""
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num

View File

@ -121,7 +121,8 @@ class FulltextQueryer:
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32: keywords.extend(syns)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
@ -147,7 +148,8 @@ class FulltextQueryer:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
if len(keywords) < 32: keywords.extend([s for s in tk_syns if s])
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]

View File

@ -104,7 +104,6 @@ class RagTokenizer:
return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist):
MAX_L = 10
res = s
# if s > MAX_L or s>= len(chars):
if s >= len(chars):
@ -184,12 +183,6 @@ class RagTokenizer:
return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks):
patts = [
(r"[ ]+", " "),
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
]
# for p,s in patts: tks = re.sub(p, s, tks)
# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split()
@ -284,7 +277,8 @@ class RagTokenizer:
same = 0
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0: res.append(" ".join(tks[j: j + same]))
if same > 0:
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1

View File

@ -62,10 +62,10 @@ class Dealer:
res = {}
f = open(fnm, "r")
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
arr = l.replace("\n", "").split("\t")
arr = line.replace("\n", "").split("\t")
if len(arr) < 2:
res[arr[0]] = 0
else:

View File

@ -47,7 +47,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
if len(chunks) <= 1: return
if len(chunks) <= 1:
return
chunks = [(s, a) for s, a in chunks if len(a) > 0]
def summarize(ck_idx, lock):
@ -66,7 +67,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt])
with lock:
if not len(embds[0]): return
if not len(embds[0]):
return
chunks.append((cnt, embds[0]))
except Exception as e:
logging.exception("summarize got exception")

View File

@ -33,14 +33,16 @@ def collect():
def main():
locations = collect()
if not locations:return
if not locations:
return
logging.info(f"TASKS: {len(locations)}")
for kb_id, loc in locations:
try:
if REDIS_CONN.is_alive():
try:
key = "{}/{}".format(kb_id, loc)
if REDIS_CONN.exist(key):continue
if REDIS_CONN.exist(key):
continue
file_bin = STORAGE_IMPL.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))

View File

@ -23,18 +23,12 @@ import os
from api.utils.log_utils import initRootLogger
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)
from datetime import datetime
import json
import os
import hashlib
import copy
import re
import sys
import time
import threading
from functools import partial
@ -63,6 +57,11 @@ from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger(CONSUMER_NAME, LOG_LEVELS)
BATCH_SIZE = 64
FACTORY = {
@ -201,7 +200,8 @@ def build_chunks(task, progress_callback):
"doc_id": task["doc_id"],
"kb_id": str(task["kb_id"])
}
if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"])
el = 0
for ck in cks:
d = copy.deepcopy(doc)
@ -342,7 +342,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:

View File

@ -41,15 +41,15 @@ def findMaxDt(fnm):
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
l = l.strip("\n")
if l == 'nan':
line = line.strip("\n")
if line == 'nan':
continue
if l > m:
m = l
except Exception as e:
if line > m:
m = line
except Exception:
pass
return m
@ -59,15 +59,15 @@ def findMaxTm(fnm):
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
line = f.readline()
if not line:
break
l = l.strip("\n")
if l == 'nan':
line = line.strip("\n")
if line == 'nan':
continue
if int(l) > m:
m = int(l)
except Exception as e:
if int(line) > m:
m = int(line)
except Exception:
pass
return m
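
The two loops above are the same scan with `l` renamed to `line`: read a value per line, skip blanks and 'nan', and keep the maximum seen. A self-contained version of the cleaned-up loop over in-memory lines with sample values:

```python
def find_max(lines):
    # Mirrors the loop above: ignore blank and 'nan' rows, track the largest value.
    m = ""
    for line in lines:
        line = line.strip("\n")
        if not line or line == "nan":
            continue
        if line > m:
            m = line
    return m

print(find_max(["2023-01-02\n", "nan\n", "2024-05-06\n"]))  # 2024-05-06
```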

View File

@ -32,7 +32,7 @@ class RAGFlowAzureSasBlob(object):
self.conn = None
def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
def put(self, bucket, fnm, binary):

View File

@ -36,7 +36,7 @@ class RAGFlowAzureSpnBlob(object):
self.conn = None
def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
f = self.conn.create_file(fnm)
f.append_data(binary, offset=0, length=len(binary))
return f.flush_data(len(binary))

View File

@ -132,7 +132,8 @@ class ESConnection(DocStoreConnection):
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v: continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):

View File

@ -1,14 +1,21 @@
from beartype.claw import beartype_this_package
beartype_this_package() # <-- raise exceptions in your code
import importlib.metadata
__version__ = importlib.metadata.version("ragflow_sdk")
from .ragflow import RAGFlow
from .modules.dataset import DataSet
from .modules.chat import Chat
from .modules.session import Session
from .modules.document import Document
from .modules.chunk import Chunk
from .modules.agent import Agent
from .modules.agent import Agent
__version__ = importlib.metadata.version("ragflow_sdk")
__all__ = [
"RAGFlow",
"DataSet",
"Chat",
"Session",
"Document",
"Chunk",
"Agent"
]

View File

@ -29,7 +29,7 @@ class Session(Base):
raise Exception(json_data["message"])
if line.startswith("data:"):
json_data = json.loads(line[5:])
if json_data["data"] != True:
if not json_data["data"]:
answer = json_data["data"]["answer"]
reference = json_data["data"]["reference"]
temp_dict = {

View File

@ -1,5 +1,3 @@
import string
import random
import os
import pytest
import requests

View File

@ -39,7 +39,6 @@ def update_dataset(auth, json_req):
def upload_file(auth, dataset_id, path):
authorization = {"Authorization": auth}
url = f"{HOST_ADDRESS}/v1/document/upload"
base_name = os.path.basename(path)
json_req = {
"kb_id": dataset_id,
}

View File

@ -1,3 +1,3 @@
def test_get_email(get_email):
print(f"\nEmail account:",flush=True)
print("\nEmail account:",flush=True)
print(f"{get_email}\n",flush=True)

View File

@ -13,14 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, upload_file, DATASET_NAME_LIMIT
from common import create_dataset, list_dataset, rm_dataset, upload_file
from common import list_document, get_docs_info, parse_docs
from time import sleep
from timeit import default_timer as timer
import re
import pytest
import random
import string
def test_parse_txt_document(get_auth):

View File

@ -1,6 +1,5 @@
from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
from common import create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
import re
import pytest
import random
import string
@ -33,8 +32,6 @@ def test_dataset(get_auth):
def test_dataset_1k_dataset(get_auth):
# create dataset
authorization = {"Authorization": get_auth}
url = f"{HOST_ADDRESS}/v1/kb/create"
for i in range(1000):
res = create_dataset(get_auth, f"test_create_dataset_{i}")
assert res.get("code") == 0, f"{res.get('message')}"
@ -76,7 +73,7 @@ def test_duplicated_name_dataset(get_auth):
dataset_id = item.get("id")
dataset_list.append(dataset_id)
match = re.match(pattern, dataset_name)
assert match != None
assert match is not None
for dataset_id in dataset_list:
res = rm_dataset(get_auth, dataset_id)
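
The assertion change above is Ruff's E711: comparisons against `None` use `is` / `is not` rather than `==` / `!=`, since `None` is a singleton and identity checks cannot be overridden by a custom `__eq__`. A tiny standalone illustration with a hypothetical dataset name:

```python
import re

match = re.match(r"test_create_dataset_\d+", "test_create_dataset_7")

# E711: identity check against the None singleton, not value equality.
assert match is not None
print(match.group(0))  # test_create_dataset_7
```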

View File

@ -1,3 +1,3 @@
def test_get_email(get_email):
print(f"\nEmail account:",flush=True)
print("\nEmail account:",flush=True)
print(f"{get_email}\n",flush=True)

View File

@ -1,4 +1,4 @@
from ragflow_sdk import RAGFlow,Agent
from ragflow_sdk import RAGFlow
from common import HOST_ADDRESS
import pytest