Fix errors detected by Ruff (#3918)
### What problem does this PR solve?

Fix errors detected by Ruff

### Type of change

- [x] Refactoring
parent e267a026f3
commit 0d68a6cd1b
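Most of the hunks below follow one pattern: compound statements that Ruff reports as E701 (multiple statements on one line after a colon) are split onto separate lines, with no behaviour change. A minimal, illustrative sketch of the pattern — not code from this repository:

```python
# Illustrative only. Ruff's E701 flags a statement placed on the same line
# as a colon; the fix applied throughout this PR is purely mechanical.
def keep_positive(values):
    # Before (flagged by Ruff E701):
    #     if not values: return []
    # After (the form used throughout the diff below):
    if not values:
        return []
    return [v for v in values if v > 0]

print(keep_positive([3, -1, 2]))  # [3, 2]
```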
@@ -133,7 +133,8 @@ class Canvas(ABC):
             "components": {}
         }
         for k in self.dsl.keys():
-            if k in ["components"]:continue
+            if k in ["components"]:
+                continue
             dsl[k] = deepcopy(self.dsl[k])
 
         for k, cpn in self.components.items():
@@ -158,7 +159,8 @@ class Canvas(ABC):
 
     def get_compnent_name(self, cid):
         for n in self.dsl["graph"]["nodes"]:
-            if cid == n["id"]: return n["data"]["name"]
+            if cid == n["id"]:
+                return n["data"]["name"]
         return ""
 
     def run(self, **kwargs):
@@ -173,7 +175,8 @@ class Canvas(ABC):
             if kwargs.get("stream"):
                 for an in ans():
                     yield an
-            else: yield ans
+            else:
+                yield ans
             return
 
         if not self.path:
@@ -188,7 +191,8 @@ class Canvas(ABC):
         def prepare2run(cpns):
             nonlocal ran, ans
             for c in cpns:
-                if self.path[-1] and c == self.path[-1][-1]: continue
+                if self.path[-1] and c == self.path[-1][-1]:
+                    continue
                 cpn = self.components[c]["obj"]
                 if cpn.component_name == "Answer":
                     self.answer.append(c)
@@ -197,7 +201,8 @@ class Canvas(ABC):
                 if c not in without_dependent_checking:
                     cpids = cpn.get_dependent_components()
                     if any([cc not in self.path[-1] for cc in cpids]):
-                        if c not in waiting: waiting.append(c)
+                        if c not in waiting:
+                            waiting.append(c)
                         continue
                 yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
                 ans = cpn.run(self.history, **kwargs)
@@ -211,10 +216,12 @@ class Canvas(ABC):
             logging.debug(f"Canvas.run: {ran} {self.path}")
             cpn_id = self.path[-1][ran]
             cpn = self.get_component(cpn_id)
-            if not cpn["downstream"]: break
+            if not cpn["downstream"]:
+                break
 
             loop = self._find_loop()
-            if loop: raise OverflowError(f"Too much loops: {loop}")
+            if loop:
+                raise OverflowError(f"Too much loops: {loop}")
 
             if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
                 switch_out = cpn["obj"].output()[1].iloc[0, 0]
@@ -283,19 +290,22 @@ class Canvas(ABC):
 
     def _find_loop(self, max_loops=6):
         path = self.path[-1][::-1]
-        if len(path) < 2: return False
+        if len(path) < 2:
+            return False
 
         for i in range(len(path)):
             if path[i].lower().find("answer") >= 0:
                 path = path[:i]
                 break
 
-        if len(path) < 2: return False
+        if len(path) < 2:
+            return False
 
-        for l in range(2, len(path) // 2):
-            pat = ",".join(path[0:l])
+        for loc in range(2, len(path) // 2):
+            pat = ",".join(path[0:loc])
             path_str = ",".join(path)
-            if len(pat) >= len(path_str): return False
+            if len(pat) >= len(path_str):
+                return False
             loop = max_loops
             while path_str.find(pat) == 0 and loop >= 0:
                 loop -= 1
@@ -303,7 +313,7 @@ class Canvas(ABC):
                     return False
                 path_str = path_str[len(pat)+1:]
             if loop < 0:
-                pat = " => ".join([p.split(":")[0] for p in path[0:l]])
+                pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
                 return pat + " => " + pat
 
         return False
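The `_find_loop` hunks above also rename the loop variable `l` to `loc`, which addresses Ruff's E741 (ambiguous single-character names such as `l` that are easy to confuse with `1` or `I`). A small, hypothetical example of the same rename:

```python
# Illustrative only; the data is made up.  The rename changes readability,
# not behaviour: the slice path[0:loc] is the same as path[0:l] was.
path = ["begin:0", "retrieval:0", "generate:0"]

for loc in range(2, len(path)):
    pat = ",".join(path[0:loc])
    print(pat)  # begin:0,retrieval:0
```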
@@ -39,3 +39,73 @@ def component_class(class_name):
     m = importlib.import_module("agent.component")
     c = getattr(m, class_name)
     return c
+
+__all__ = [
+    "Begin", "BeginParam", "Generate", "GenerateParam", "Retrieval", "RetrievalParam",
+    "Answer", "AnswerParam", "Categorize", "CategorizeParam", "Switch", "SwitchParam",
+    "Relevant", "RelevantParam", "Message", "MessageParam", "RewriteQuestion", "RewriteQuestionParam",
+    "KeywordExtract", "KeywordExtractParam", "Concentrator", "ConcentratorParam", "Baidu", "BaiduParam",
+    "DuckDuckGo", "DuckDuckGoParam", "Wikipedia", "WikipediaParam", "PubMed", "PubMedParam",
+    "ArXiv", "ArXivParam", "Google", "GoogleParam", "Bing", "BingParam",
+    "GoogleScholar", "GoogleScholarParam", "DeepL", "DeepLParam", "GitHub", "GitHubParam",
+    "BaiduFanyi", "BaiduFanyiParam", "QWeather", "QWeatherParam", "ExeSQL", "ExeSQLParam",
+    "YahooFinance", "YahooFinanceParam", "WenCai", "WenCaiParam", "Jin10", "Jin10Param",
+    "TuShare", "TuShareParam", "AkShare", "AkShareParam", "Crawler", "CrawlerParam",
+    "Invoke", "InvokeParam", "Template", "TemplateParam", "Email", "EmailParam",
+    "component_class",
+]
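The new `__all__` list declares the package's public exports, so star-imports are well defined and the names re-exported from the component package's `__init__` are no longer reported as unused imports (Ruff F401). A self-contained sketch of the idiom, with made-up names:

```python
# Illustrative only; Widget/WidgetParam are hypothetical.  In a package
# __init__ the entries would typically be names imported from submodules;
# listing them in __all__ marks the re-export as intentional for linters.
class Widget:
    pass


class WidgetParam:
    pass


__all__ = ["Widget", "WidgetParam"]

print(__all__)  # ['Widget', 'WidgetParam']
```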
@@ -428,7 +428,8 @@ class ComponentBase(ABC):
     def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
         o = getattr(self._param, self._param.output_var_name)
         if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
-            if not isinstance(o, list): o = [o]
+            if not isinstance(o, list):
+                o = [o]
             o = pd.DataFrame(o)
 
         if allow_partial or not isinstance(o, partial):
@@ -440,7 +441,8 @@ class ComponentBase(ABC):
         for oo in o():
             if not isinstance(oo, pd.DataFrame):
                 outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
-            else: outs = oo
+            else:
+                outs = oo
         return self._param.output_var_name, outs
 
     def reset(self):
@@ -482,13 +484,15 @@ class ComponentBase(ABC):
                     outs.append(pd.DataFrame([{"content": q["value"]}]))
             if outs:
                 df = pd.concat(outs, ignore_index=True)
-                if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
+                if "content" in df:
+                    df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
                 return df
 
         upstream_outs = []
 
         for u in reversed_cpnts[::-1]:
-            if self.get_component_name(u) in ["switch", "concentrator"]: continue
+            if self.get_component_name(u) in ["switch", "concentrator"]:
+                continue
             if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
                 o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                 if o is not None:
@@ -532,7 +536,8 @@ class ComponentBase(ABC):
             reversed_cpnts.extend(self._canvas.path[-1])
 
         for u in reversed_cpnts[::-1]:
-            if self.get_component_name(u) in ["switch", "answer"]: continue
+            if self.get_component_name(u) in ["switch", "answer"]:
+                continue
             return self._canvas.get_component(u)["obj"].output()[1]
 
     @staticmethod
@@ -34,15 +34,18 @@ class CategorizeParam(GenerateParam):
         super().check()
         self.check_empty(self.category_description, "[Categorize] Category examples")
         for k, v in self.category_description.items():
-            if not k: raise ValueError("[Categorize] Category name can not be empty!")
-            if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
+            if not k:
+                raise ValueError("[Categorize] Category name can not be empty!")
+            if not v.get("to"):
+                raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
 
     def get_prompt(self):
         cate_lines = []
         for c, desc in self.category_description.items():
-            for l in desc.get("examples", "").split("\n"):
-                if not l: continue
-                cate_lines.append("Question: {}\tCategory: {}".format(l, c))
+            for line in desc.get("examples", "").split("\n"):
+                if not line:
+                    continue
+                cate_lines.append("Question: {}\tCategory: {}".format(line, c))
         descriptions = []
         for c, desc in self.category_description.items():
             if desc.get("description"):
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 from abc import ABC
-import re
 from agent.component.base import ComponentBase, ComponentParamBase
 import deepl
 
@@ -46,8 +46,10 @@ class ExeSQLParam(ComponentParamBase):
         self.check_empty(self.password, "Database password")
         self.check_positive_integer(self.top_n, "Number of records")
         if self.database == "rag_flow":
-            if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
-            if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
+            if self.host == "ragflow-mysql":
+                raise ValueError("The host is not accessible.")
+            if self.password == "infini_rag_flow":
+                raise ValueError("The host is not accessible.")
 
 
 class ExeSQL(ComponentBase, ABC):
@@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase):
 
     def gen_conf(self):
         conf = {}
-        if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
-        if self.temperature > 0: conf["temperature"] = self.temperature
-        if self.top_p > 0: conf["top_p"] = self.top_p
-        if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
-        if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
+        if self.max_tokens > 0:
+            conf["max_tokens"] = self.max_tokens
+        if self.temperature > 0:
+            conf["temperature"] = self.temperature
+        if self.top_p > 0:
+            conf["top_p"] = self.top_p
+        if self.presence_penalty > 0:
+            conf["presence_penalty"] = self.presence_penalty
+        if self.frequency_penalty > 0:
+            conf["frequency_penalty"] = self.frequency_penalty
         return conf
 
 
@@ -83,7 +88,8 @@ class Generate(ComponentBase):
         recall_docs = []
         for i in idx:
             did = retrieval_res.loc[int(i), "doc_id"]
-            if did in doc_ids: continue
+            if did in doc_ids:
+                continue
             doc_ids.add(did)
             recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
 
@@ -108,7 +114,8 @@ class Generate(ComponentBase):
         retrieval_res = []
         self._param.inputs = []
         for para in self._param.parameters:
-            if not para.get("component_id"): continue
+            if not para.get("component_id"):
+                continue
             component_id = para["component_id"].split("@")[0]
             if para["component_id"].lower().find("@") >= 0:
                 cpn_id, key = para["component_id"].split("@")
@@ -142,7 +149,8 @@ class Generate(ComponentBase):
 
         if retrieval_res:
             retrieval_res = pd.concat(retrieval_res, ignore_index=True)
-        else: retrieval_res = pd.DataFrame([])
+        else:
+            retrieval_res = pd.DataFrame([])
 
         for n, v in kwargs.items():
             prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
@@ -164,9 +172,11 @@ class Generate(ComponentBase):
             return pd.DataFrame([res])
 
         msg = self._canvas.get_history(self._param.message_history_window_size)
-        if len(msg) < 1: msg.append({"role": "user", "content": ""})
+        if len(msg) < 1:
+            msg.append({"role": "user", "content": ""})
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
-        if len(msg) < 2: msg.append({"role": "user", "content": ""})
+        if len(msg) < 2:
+            msg.append({"role": "user", "content": ""})
         ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
 
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
@@ -185,9 +195,11 @@ class Generate(ComponentBase):
             return
 
         msg = self._canvas.get_history(self._param.message_history_window_size)
-        if len(msg) < 1: msg.append({"role": "user", "content": ""})
+        if len(msg) < 1:
+            msg.append({"role": "user", "content": ""})
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
-        if len(msg) < 2: msg.append({"role": "user", "content": ""})
+        if len(msg) < 2:
+            msg.append({"role": "user", "content": ""})
         answer = ""
         for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
             res = {"content": ans, "reference": []}
@@ -95,7 +95,8 @@ class RewriteQuestion(Generate, ABC):
         hist = self._canvas.get_history(4)
         conv = []
         for m in hist:
-            if m["role"] not in ["user", "assistant"]: continue
+            if m["role"] not in ["user", "assistant"]:
+                continue
             conv.append("{}: {}".format(m["role"].upper(), m["content"]))
         conv = "\n".join(conv)
 
@@ -41,7 +41,8 @@ class SwitchParam(ComponentParamBase):
     def check(self):
         self.check_empty(self.conditions, "[Switch] conditions")
         for cond in self.conditions:
-            if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
+            if not cond["to"]:
+                raise ValueError("[Switch] 'To' can not be empty!")
 
 
 class Switch(ComponentBase, ABC):
@@ -51,7 +52,8 @@ class Switch(ComponentBase, ABC):
         res = []
         for cond in self._param.conditions:
             for item in cond["items"]:
-                if not item["cpn_id"]: continue
+                if not item["cpn_id"]:
+                    continue
                 if item["cpn_id"].find("begin") >= 0:
                     continue
                 cid = item["cpn_id"].split("@")[0]
@@ -63,7 +65,8 @@ class Switch(ComponentBase, ABC):
         for cond in self._param.conditions:
             res = []
             for item in cond["items"]:
-                if not item["cpn_id"]:continue
+                if not item["cpn_id"]:
+                    continue
                 cid = item["cpn_id"].split("@")[0]
                 if item["cpn_id"].find("@") > 0:
                     cpn_id, key = item["cpn_id"].split("@")
@@ -107,22 +110,22 @@ class Switch(ComponentBase, ABC):
         elif operator == ">":
             try:
                 return True if float(input) > float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input > value else False
         elif operator == "<":
             try:
                 return True if float(input) < float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input < value else False
         elif operator == "≥":
            try:
                 return True if float(input) >= float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input >= value else False
         elif operator == "≤":
             try:
                 return True if float(input) <= float(value) else False
-            except Exception as e:
+            except Exception:
                 return True if input <= value else False
 
         raise ValueError('Not supported operator' + operator)
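In the operator comparisons above, `except Exception as e:` bound a variable that was never used, which Ruff reports as F841 (unused local); dropping the `as e` is the whole fix. Illustrative example:

```python
# Illustrative only.  Binding the exception to a name that is never read
# trips Ruff's F841; "except Exception:" keeps the same control flow.
def as_number(value):
    try:
        return float(value)
    except Exception:  # was: "except Exception as e:" with e never used
        return None

print(as_number("3.5"), as_number("n/a"))  # 3.5 None
```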
@@ -47,7 +47,8 @@ class Template(ComponentBase):
 
         self._param.inputs = []
         for para in self._param.parameters:
-            if not para.get("component_id"): continue
+            if not para.get("component_id"):
+                continue
             component_id = para["component_id"].split("@")[0]
             if para["component_id"].lower().find("@") >= 0:
                 cpn_id, key = para["component_id"].split("@")
@@ -43,6 +43,7 @@ if __name__ == '__main__':
         else:
             print(ans["content"])
 
-        if DEBUG: print(canvas.path)
+        if DEBUG:
+            print(canvas.path)
         question = input("\n==================== User =====================\n> ")
         canvas.add_user_input(question)
@@ -142,7 +142,6 @@ def set_conversation():
     if not objs:
         return get_json_result(
             data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
-    req = request.json
     try:
         if objs[0].source == "agent":
            e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
@@ -188,7 +187,8 @@ def completion():
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
         return get_data_error_result(message="Conversation not found!")
-    if "quote" not in req: req["quote"] = False
+    if "quote" not in req:
+        req["quote"] = False
 
     msg = []
     for m in req["messages"]:
@@ -197,7 +197,8 @@ def completion():
         if m["role"] == "assistant" and not msg:
             continue
         msg.append(m)
-    if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
+    if not msg[-1].get("id"):
+        msg[-1]["id"] = get_uuid()
     message_id = msg[-1]["id"]
 
     def fillin_conv(ans):
@@ -674,11 +675,13 @@ def completion_faq():
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
         return get_data_error_result(message="Conversation not found!")
-    if "quote" not in req: req["quote"] = True
+    if "quote" not in req:
+        req["quote"] = True
 
     msg = []
     msg.append({"role": "user", "content": req["word"]})
-    if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
+    if not msg[-1].get("id"):
+        msg[-1]["id"] = get_uuid()
     message_id = msg[-1]["id"]
 
     def fillin_conv(ans):
@@ -13,10 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import json
 import traceback
-from functools import partial
 from flask import request, Response
 from flask_login import login_required, current_user
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
@@ -60,7 +58,8 @@ def rm():
 def save():
     req = request.json
     req["user_id"] = current_user.id
-    if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
+    if not isinstance(req["dsl"], str):
+        req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
 
     req["dsl"] = json.loads(req["dsl"])
     if "id" not in req:
@@ -153,7 +152,8 @@ def run():
             return resp
 
         for answer in canvas.run(stream=False):
-            if answer.get("running_status"): continue
+            if answer.get("running_status"):
+                continue
             final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
         canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
         if final_ans.get("reference"):
@@ -237,7 +237,8 @@ def create():
         e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")
-        if kb.pagerank: d["pagerank_fea"] = kb.pagerank
+        if kb.pagerank:
+            d["pagerank_fea"] = kb.pagerank
 
         embd_id = DocumentService.get_embd_id(req["doc_id"])
         embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
@@ -281,10 +281,12 @@ def thumbup():
         if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
             if up_down:
                 msg["thumbup"] = True
-                if "feedback" in msg: del msg["feedback"]
+                if "feedback" in msg:
+                    del msg["feedback"]
             else:
                 msg["thumbup"] = False
-                if feedback: msg["feedback"] = feedback
+                if feedback:
+                    msg["feedback"] = feedback
             break
 
     ConversationService.update_by_id(conv["id"], conv)
@@ -37,10 +37,12 @@ def set_dialog():
     top_n = req.get("top_n", 6)
     top_k = req.get("top_k", 1024)
     rerank_id = req.get("rerank_id", "")
-    if not rerank_id: req["rerank_id"] = ""
+    if not rerank_id:
+        req["rerank_id"] = ""
     similarity_threshold = req.get("similarity_threshold", 0.1)
     vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
-    if vector_similarity_weight is None: vector_similarity_weight = 0.3
+    if vector_similarity_weight is None:
+        vector_similarity_weight = 0.3
     llm_setting = req.get("llm_setting", {})
     default_prompt = {
         "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-import json
 import os.path
 import pathlib
 import re
@@ -90,7 +89,8 @@ def web_crawl():
         raise LookupError("Can't find this knowledgebase!")
 
     blob = html2pdf(url)
-    if not blob: return server_error_response(ValueError("Download failure."))
+    if not blob:
+        return server_error_response(ValueError("Download failure."))
 
     root_folder = FileService.get_root_folder(current_user.id)
     pf_id = root_folder["id"]
@@ -290,7 +290,8 @@ def change_status():
 def rm():
     req = request.json
     doc_ids = req["doc_id"]
-    if isinstance(doc_ids, str): doc_ids = [doc_ids]
+    if isinstance(doc_ids, str):
+        doc_ids = [doc_ids]
 
     for doc_id in doc_ids:
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
@@ -351,8 +351,10 @@ def list_app():
 
         llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
         for o in objs:
-            if not o.api_key: continue
-            if o.llm_name + "@" + o.llm_factory in llm_set: continue
+            if not o.api_key:
+                continue
+            if o.llm_name + "@" + o.llm_factory in llm_set:
+                continue
             llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
 
         res = {}
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
+from api.db.services.canvas_service import UserCanvasService
 from api.utils.api_utils import get_error_data_result, token_required
 from api.utils.api_utils import get_result
 from flask import request
@@ -41,7 +41,6 @@ from api.utils.api_utils import construct_json_result, get_parser_config
 from rag.nlp import search
 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL
-import os
 
 MAXIMUM_OF_UPLOADING_FILES = 256
 
@@ -976,12 +975,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
     if not req.get("content"):
         return get_error_data_result(message="`content` is required")
     if "important_keywords" in req:
-        if type(req["important_keywords"]) != list:
+        if not isinstance(req["important_keywords"], list):
            return get_error_data_result(
                 "`important_keywords` is required to be a list"
             )
     if "questions" in req:
-        if type(req["questions"]) != list:
+        if not isinstance(req["questions"], list):
             return get_error_data_result(
                 "`questions` is required to be a list"
             )
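In `add_chunk` above, `type(x) != list` is replaced with `isinstance(x, list)`, the check Ruff prefers (E721): it also accepts subclasses and avoids comparing type objects directly. A short sketch with hypothetical names:

```python
# Illustrative only; the function and error message are made up to mirror
# the pattern in the hunk above.
def validate_keywords(req: dict) -> None:
    if "important_keywords" in req:
        if not isinstance(req["important_keywords"], list):
            raise TypeError("`important_keywords` is required to be a list")

validate_keywords({"important_keywords": ["ruff", "lint"]})  # passes silently
print("ok")
```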
@@ -143,8 +143,10 @@ def completion(tenant_id, chat_id):
         }
     conv.message.append(question)
     for m in conv.message:
-        if m["role"] == "system": continue
-        if m["role"] == "assistant" and not msg: continue
+        if m["role"] == "system":
+            continue
+        if m["role"] == "assistant" and not msg:
+            continue
         msg.append(m)
     message_id = msg[-1].get("id")
     e, dia = DialogService.get_by_id(conv.dialog_id)
@@ -267,7 +269,8 @@ def agent_completion(tenant_id, agent_id):
         if m["role"] == "assistant" and not msg:
             continue
         msg.append(m)
-    if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
+    if not msg[-1].get("id"):
+        msg[-1]["id"] = get_uuid()
     message_id = msg[-1]["id"]
 
     stream = req.get("stream", True)
@@ -361,7 +364,8 @@ def agent_completion(tenant_id, agent_id):
             return resp
 
        for answer in canvas.run(stream=False):
-            if answer.get("running_status"): continue
+            if answer.get("running_status"):
+                continue
             final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
         canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
         if final_ans.get("reference"):
@@ -330,7 +330,7 @@ def user_info_from_github(access_token):
         headers=headers,
     ).json()
     user_info["email"] = next(
-        (email for email in email_info if email["primary"] == True), None
+        (email for email in email_info if email["primary"]), None
    )["email"]
     return user_info
 
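`email["primary"] == True` becomes a plain truthiness test, the form Ruff suggests for E712; for the boolean flag carried in this payload the two are equivalent. Illustrative example with made-up data:

```python
# Illustrative only; the records are invented, but "primary" is a boolean
# as in the GitHub email payload, so truthiness matches "== True" here.
email_info = [
    {"email": "a@example.com", "primary": False},
    {"email": "b@example.com", "primary": True},
]
primary = next((e for e in email_info if e["primary"]), None)
print(primary["email"])  # b@example.com
```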
@@ -130,7 +130,7 @@ def is_continuous_field(cls: typing.Type) -> bool:
     for p in cls.__bases__:
         if p in CONTINUOUS_FIELD_TYPE:
             return True
-        elif p != Field and p != object:
+        elif p is not Field and p is not object:
             if is_continuous_field(p):
                 return True
         else:
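`p != Field` becomes `p is not Field`: when comparing class objects, an identity test states exactly what is meant and does not depend on any custom `__eq__`. A small sketch with hypothetical classes:

```python
# Illustrative only; Field/CharField are stand-ins for the model field
# classes referenced in the hunk above.
class Field:
    pass


class CharField(Field):
    pass


for p in (Field, CharField, object):
    if p is not Field and p is not object:
        print(p.__name__, "would be recursed into")  # CharField
```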
@@ -170,7 +170,7 @@ def add_graph_templates():
             cnvs = json.load(open(os.path.join(dir, fnm), "r"))
             try:
                 CanvasTemplateService.save(**cnvs)
-            except:
+            except Exception:
                 CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
     except Exception:
         logging.exception("Add graph templates error: ")
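The bare `except:` above is what Ruff reports as E722: it would also swallow `KeyboardInterrupt` and `SystemExit`. Narrowing it to `except Exception:` keeps the fallback behaviour while satisfying the linter. Illustrative example:

```python
# Illustrative only; load_or_default is a made-up helper showing the same
# narrowing from a bare "except:" to "except Exception:".
import json

def load_or_default(text, default=None):
    try:
        return json.loads(text)
    except Exception:  # was a bare "except:" before the fix
        return default

print(load_or_default('{"a": 1}'), load_or_default("not json", {}))  # {'a': 1} {}
```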
@@ -15,13 +15,14 @@
 #
 import pathlib
 import re
-from .user_service import UserService
+from .user_service import UserService as UserService
 
 
 def duplicate_name(query_func, **kwargs):
     fnm = kwargs["name"]
     objs = query_func(**kwargs)
-    if not objs: return fnm
+    if not objs:
+        return fnm
     ext = pathlib.Path(fnm).suffix #.jpg
     nm = re.sub(r"%s$"%ext, "", fnm)
     r = re.search(r"\(([0-9]+)\)$", nm)
@@ -31,8 +32,8 @@ def duplicate_name(query_func, **kwargs):
         nm = re.sub(r"\([0-9]+\)$", "", nm)
     c += 1
     nm = f"{nm}({c})"
-    if ext: nm += f"{ext}"
+    if ext:
+        nm += f"{ext}"
 
     kwargs["name"] = nm
     return duplicate_name(query_func, **kwargs)
 
@@ -64,7 +64,8 @@ class API4ConversationService(CommonService):
     @classmethod
     @DB.connection_context()
     def stats(cls, tenant_id, from_date, to_date, source=None):
-        if len(to_date) == 10: to_date += " 23:59:59"
+        if len(to_date) == 10:
+            to_date += " 23:59:59"
         return cls.model.select(
             cls.model.create_date.truncate("day").alias("dt"),
             peewee.fn.COUNT(
@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from datetime import datetime
-import peewee
-from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
+from api.db.db_models import DB, CanvasTemplate, UserCanvas
 from api.db.services.common_service import CommonService
 
 
@@ -115,7 +115,7 @@ class CommonService:
         try:
             obj = cls.model.query(id=pid)[0]
             return True, obj
-        except Exception as e:
+        except Exception:
            return False, None
 
     @classmethod
@@ -106,15 +106,15 @@ def message_fit_in(msg, max_length=4000):
         return c, msg
 
     ll = num_tokens_from_string(msg_[0]["content"])
-    l = num_tokens_from_string(msg_[-1]["content"])
-    if ll / (ll + l) > 0.8:
+    ll2 = num_tokens_from_string(msg_[-1]["content"])
+    if ll / (ll + ll2) > 0.8:
         m = msg_[0]["content"]
-        m = encoder.decode(encoder.encode(m)[:max_length - l])
+        m = encoder.decode(encoder.encode(m)[:max_length - ll2])
         msg[0]["content"] = m
         return max_length, msg
 
     m = msg_[1]["content"]
-    m = encoder.decode(encoder.encode(m)[:max_length - l])
+    m = encoder.decode(encoder.encode(m)[:max_length - ll2])
     msg[1]["content"] = m
     return max_length, msg
 
@ -257,7 +257,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||||
recall_docs = [
|
recall_docs = [
|
||||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||||
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
if not recall_docs:
|
||||||
|
recall_docs = kbinfos["doc_aggs"]
|
||||||
kbinfos["doc_aggs"] = recall_docs
|
kbinfos["doc_aggs"] = recall_docs
|
||||||
|
|
||||||
refs = deepcopy(kbinfos)
|
refs = deepcopy(kbinfos)
|
||||||
@ -433,13 +434,15 @@ def relevant(tenant_id, llm_id, question, contents: list):
|
|||||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
||||||
No other words needed except 'yes' or 'no'.
|
No other words needed except 'yes' or 'no'.
|
||||||
"""
|
"""
|
||||||
if not contents:return False
|
if not contents:
|
||||||
|
return False
|
||||||
contents = "Documents: \n" + " - ".join(contents)
|
contents = "Documents: \n" + " - ".join(contents)
|
||||||
contents = f"Question: {question}\n" + contents
|
contents = f"Question: {question}\n" + contents
|
||||||
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
|
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
|
||||||
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
|
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
|
||||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
|
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
|
||||||
if ans.lower().find("yes") >= 0: return True
|
if ans.lower().find("yes") >= 0:
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -481,8 +484,10 @@ Requirements:
|
|||||||
]
|
]
|
||||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(kwd, tuple): kwd = kwd[0]
|
if isinstance(kwd, tuple):
|
||||||
if kwd.find("**ERROR**") >=0: return ""
|
kwd = kwd[0]
|
||||||
|
if kwd.find("**ERROR**") >=0:
|
||||||
|
return ""
|
||||||
return kwd
|
return kwd
|
||||||
|
|
||||||
|
|
||||||
@ -508,8 +513,10 @@ Requirements:
|
|||||||
]
|
]
|
||||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(kwd, tuple): kwd = kwd[0]
|
if isinstance(kwd, tuple):
|
||||||
if kwd.find("**ERROR**") >= 0: return ""
|
kwd = kwd[0]
|
||||||
|
if kwd.find("**ERROR**") >= 0:
|
||||||
|
return ""
|
||||||
return kwd
|
return kwd
|
||||||
|
|
||||||
|
|
||||||
@ -520,7 +527,8 @@ def full_question(tenant_id, llm_id, messages):
|
|||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||||
conv = []
|
conv = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
if m["role"] not in ["user", "assistant"]: continue
|
if m["role"] not in ["user", "assistant"]:
|
||||||
|
continue
|
||||||
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
|
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
|
||||||
conv = "\n".join(conv)
|
conv = "\n".join(conv)
|
||||||
today = datetime.date.today().isoformat()
|
today = datetime.date.today().isoformat()
|
||||||
@ -581,7 +589,8 @@ Output: What's the weather in Rochester on {tomorrow}?
|
|||||||
|
|
||||||
|
|
||||||
def tts(tts_mdl, text):
|
def tts(tts_mdl, text):
|
||||||
if not tts_mdl or not text: return
|
if not tts_mdl or not text:
|
||||||
|
return
|
||||||
bin = b""
|
bin = b""
|
||||||
for chunk in tts_mdl.tts(text):
|
for chunk in tts_mdl.tts(text):
|
||||||
bin += chunk
|
bin += chunk
|
||||||
@ -641,7 +650,8 @@ def ask(question, kb_ids, tenant_id):
|
|||||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||||
recall_docs = [
|
recall_docs = [
|
||||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||||
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
if not recall_docs:
|
||||||
|
recall_docs = kbinfos["doc_aggs"]
|
||||||
kbinfos["doc_aggs"] = recall_docs
|
kbinfos["doc_aggs"] = recall_docs
|
||||||
refs = deepcopy(kbinfos)
|
refs = deepcopy(kbinfos)
|
||||||
for c in refs["chunks"]:
|
for c in refs["chunks"]:
|
||||||
|
@ -532,7 +532,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
try:
|
try:
|
||||||
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
|
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
|
||||||
ensure_ascii=False, indent=2)
|
ensure_ascii=False, indent=2)
|
||||||
if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
|
if len(mind_map) < 32:
|
||||||
|
raise Exception("Few content: " + mind_map)
|
||||||
cks.append({
|
cks.append({
|
||||||
"id": get_uuid(),
|
"id": get_uuid(),
|
||||||
"doc_id": doc_id,
|
"doc_id": doc_id,
|
||||||
|
@ -20,7 +20,7 @@ from api.db.db_models import DB
|
|||||||
from api.db.db_models import File, File2Document
|
from api.db.db_models import File, File2Document
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
from api.utils import current_timestamp, datetime_format
|
||||||
|
|
||||||
|
|
||||||
class File2DocumentService(CommonService):
|
class File2DocumentService(CommonService):
|
||||||
@ -63,7 +63,7 @@ class File2DocumentService(CommonService):
|
|||||||
def update_by_file_id(cls, file_id, obj):
|
def update_by_file_id(cls, file_id, obj):
|
||||||
obj["update_time"] = current_timestamp()
|
obj["update_time"] = current_timestamp()
|
||||||
obj["update_date"] = datetime_format(datetime.now())
|
obj["update_date"] = datetime_format(datetime.now())
|
||||||
num = cls.model.update(obj).where(cls.model.id == file_id).execute()
|
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||||
e, obj = cls.get_by_id(cls.model.id)
|
e, obj = cls.get_by_id(cls.model.id)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
@ -85,7 +85,8 @@ class FileService(CommonService):
|
|||||||
.join(Document, on=(File2Document.document_id == Document.id))
|
.join(Document, on=(File2Document.document_id == Document.id))
|
||||||
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
||||||
.where(cls.model.id == file_id))
|
.where(cls.model.id == file_id))
|
||||||
if not kbs: return []
|
if not kbs:
|
||||||
|
return []
|
||||||
kbs_info_list = []
|
kbs_info_list = []
|
||||||
for kb in list(kbs.dicts()):
|
for kb in list(kbs.dicts()):
|
||||||
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
|
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
|
||||||
@ -304,7 +305,8 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
|
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
|
||||||
for _ in File2DocumentService.get_by_document_id(doc["id"]): return
|
for _ in File2DocumentService.get_by_document_id(doc["id"]):
|
||||||
|
return
|
||||||
file = {
|
file = {
|
||||||
"id": get_uuid(),
|
"id": get_uuid(),
|
||||||
"parent_id": kb_folder_id,
|
"parent_id": kb_folder_id,
|
||||||
|
@ -107,7 +107,8 @@ class TenantLLMService(CommonService):
|
|||||||
|
|
||||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||||
if model_config: model_config = model_config.to_dict()
|
if model_config:
|
||||||
|
model_config = model_config.to_dict()
|
||||||
if not model_config:
|
if not model_config:
|
||||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||||
|
@ -57,28 +57,33 @@ class TaskService(CommonService):
|
|||||||
Tenant.img2txt_id,
|
Tenant.img2txt_id,
|
||||||
Tenant.asr_id,
|
Tenant.asr_id,
|
||||||
Tenant.llm_id,
|
Tenant.llm_id,
|
||||||
cls.model.update_time]
|
cls.model.update_time,
|
||||||
docs = cls.model.select(*fields) \
|
]
|
||||||
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
docs = (
|
||||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
|
cls.model.select(*fields)
|
||||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
|
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||||
|
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||||
|
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||||
.where(cls.model.id == task_id)
|
.where(cls.model.id == task_id)
|
||||||
|
)
|
||||||
docs = list(docs.dicts())
|
docs = list(docs.dicts())
|
||||||
if not docs: return None
|
if not docs:
|
||||||
|
return None
|
||||||
|
|
||||||
msg = "\nTask has been received."
|
msg = "\nTask has been received."
|
||||||
prog = random.random() / 10.
|
prog = random.random() / 10.0
|
||||||
if docs[0]["retry_count"] >= 3:
|
if docs[0]["retry_count"] >= 3:
|
||||||
msg = "\nERROR: Task is abandoned after 3 times attempts."
|
msg = "\nERROR: Task is abandoned after 3 times attempts."
|
||||||
prog = -1
|
prog = -1
|
||||||
|
|
||||||
cls.model.update(progress_msg=cls.model.progress_msg + msg,
|
cls.model.update(
|
||||||
progress=prog,
|
progress_msg=cls.model.progress_msg + msg,
|
||||||
retry_count=docs[0]["retry_count"]+1
|
progress=prog,
|
||||||
).where(
|
retry_count=docs[0]["retry_count"] + 1,
|
||||||
cls.model.id == docs[0]["id"]).execute()
|
).where(cls.model.id == docs[0]["id"]).execute()
|
||||||
|
|
||||||
if docs[0]["retry_count"] >= 3: return None
|
if docs[0]["retry_count"] >= 3:
|
||||||
|
return None
|
||||||
|
|
||||||
return docs[0]
|
return docs[0]
|
||||||
|
|
||||||
@ -86,21 +91,44 @@ class TaskService(CommonService):
|
|||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_ongoing_doc_name(cls):
|
def get_ongoing_doc_name(cls):
|
||||||
with DB.lock("get_task", -1):
|
with DB.lock("get_task", -1):
|
||||||
docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
|
docs = (
|
||||||
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
cls.model.select(
|
||||||
.join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
|
*[Document.id, Document.kb_id, Document.location, File.parent_id]
|
||||||
.join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
|
)
|
||||||
|
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||||
|
.join(
|
||||||
|
File2Document,
|
||||||
|
on=(File2Document.document_id == Document.id),
|
||||||
|
join_type=JOIN.LEFT_OUTER,
|
||||||
|
)
|
||||||
|
.join(
|
||||||
|
File,
|
||||||
|
on=(File2Document.file_id == File.id),
|
||||||
|
join_type=JOIN.LEFT_OUTER,
|
||||||
|
)
|
||||||
.where(
|
.where(
|
||||||
Document.status == StatusEnum.VALID.value,
|
Document.status == StatusEnum.VALID.value,
|
||||||
Document.run == TaskStatus.RUNNING.value,
|
Document.run == TaskStatus.RUNNING.value,
|
||||||
~(Document.type == FileType.VIRTUAL.value),
|
~(Document.type == FileType.VIRTUAL.value),
|
||||||
cls.model.progress < 1,
|
cls.model.progress < 1,
|
||||||
cls.model.create_time >= current_timestamp() - 1000 * 600
|
cls.model.create_time >= current_timestamp() - 1000 * 600,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
docs = list(docs.dicts())
|
docs = list(docs.dicts())
|
||||||
if not docs: return []
|
if not docs:
|
||||||
|
return []
|
||||||
|
|
||||||
return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
|
return list(
|
||||||
|
set(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
d["parent_id"] if d["parent_id"] else d["kb_id"],
|
||||||
|
d["location"],
|
||||||
|
)
|
||||||
|
for d in docs
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
@@ -118,28 +146,30 @@ class TaskService(CommonService):
     def update_progress(cls, id, info):
         if os.environ.get("MACOS"):
             if info["progress_msg"]:
-                cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
-                    cls.model.id == id).execute()
+                cls.model.update(
+                    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
+                ).where(cls.model.id == id).execute()
             if "progress" in info:
-                cls.model.update(progress=info["progress"]).where(
-                    cls.model.id == id).execute()
+                cls.model.update(progress=info["progress"]).where(
+                    cls.model.id == id
+                ).execute()
             return

         with DB.lock("update_progress", -1):
             if info["progress_msg"]:
-                cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
-                    cls.model.id == id).execute()
+                cls.model.update(
+                    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
+                ).where(cls.model.id == id).execute()
             if "progress" in info:
-                cls.model.update(progress=info["progress"]).where(
-                    cls.model.id == id).execute()
+                cls.model.update(progress=info["progress"]).where(
+                    cls.model.id == id
+                ).execute()


 def queue_tasks(doc: dict, bucket: str, name: str):
     def new_task():
-        return {
-            "id": get_uuid(),
-            "doc_id": doc["id"]
-        }
+        return {"id": get_uuid(), "doc_id": doc["id"]}
+
     tsks = []

     if doc["type"] == FileType.PDF.value:
@@ -150,8 +180,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         if doc["parser_id"] == "paper":
             page_size = doc["parser_config"].get("task_page_size", 22)
         if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
-            page_size = 10 ** 9
-        page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
+            page_size = 10**9
+        page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
         for s, e in page_ranges:
             s -= 1
             s = max(0, s)
@@ -177,4 +207,6 @@ def queue_tasks(doc: dict, bucket: str, name: str):
     DocumentService.begin2parse(doc["id"])

     for t in tsks:
-        assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
+        assert REDIS_CONN.queue_product(
+            SVR_QUEUE_NAME, message=t
+        ), "Can't access Redis. Please check the Redis' status."
@@ -22,7 +22,7 @@ from api.db import UserTenantRole
 from api.db.db_models import DB, UserTenant
 from api.db.db_models import User, Tenant
 from api.db.services.common_service import CommonService
-from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
+from api.utils import get_uuid, current_timestamp, datetime_format
 from api.db import StatusEnum
@@ -21,10 +21,7 @@
 import logging
 import os
 from api.utils.log_utils import initRootLogger
-LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
-initRootLogger("ragflow_server", LOG_LEVELS)

-import os
 import signal
 import sys
 import time
@@ -44,6 +41,9 @@ from api.versions import get_ragflow_version
 from api.utils import show_configs
 from rag.settings import print_rag_settings

+LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
+initRootLogger("ragflow_server", LOG_LEVELS)
+

 def update_progress():
     while True:
@@ -36,7 +36,6 @@ from werkzeug.http import HTTP_STATUS_CODES
 from api.db.db_models import APIToken
 from api import settings

-from api import settings
 from api.utils import CustomJSONEncoder, get_uuid
 from api.utils import json_dumps
 from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC
@@ -45,5 +45,5 @@ try:
     pool = Pool(processes=1)
     thread = pool.apply_async(download_nltk_data)
     binary = thread.get(timeout=60)
-except Exception as e:
+except Exception:
     print('\x1b[6;37;41m WARNING \x1b[0m' + "Downloading NLTK data failure.", flush=True)
@@ -18,4 +18,16 @@ from .ppt_parser import RAGFlowPptParser as PptParser
 from .html_parser import RAGFlowHtmlParser as HtmlParser
 from .json_parser import RAGFlowJsonParser as JsonParser
 from .markdown_parser import RAGFlowMarkdownParser as MarkdownParser
 from .txt_parser import RAGFlowTxtParser as TxtParser
+
+__all__ = [
+    "PdfParser",
+    "PlainParser",
+    "DocxParser",
+    "ExcelParser",
+    "PptParser",
+    "HtmlParser",
+    "JsonParser",
+    "MarkdownParser",
+    "TxtParser",
+]
@@ -29,7 +29,8 @@ class RAGFlowExcelParser:
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
             rows = list(ws.rows)
-            if not rows: continue
+            if not rows:
+                continue

             tb_rows_0 = "<tr>"
             for t in list(rows[0]):
@@ -40,7 +41,9 @@ class RAGFlowExcelParser:
                 tb = ""
                 tb += f"<table><caption>{sheetname}</caption>"
                 tb += tb_rows_0
-                for r in list(rows[1 + chunk_i * chunk_rows:1 + (chunk_i + 1) * chunk_rows]):
+                for r in list(
+                    rows[1 + chunk_i * chunk_rows : 1 + (chunk_i + 1) * chunk_rows]
+                ):
                     tb += "<tr>"
                     for i, c in enumerate(r):
                         if c.value is None:
@@ -62,20 +65,21 @@ class RAGFlowExcelParser:
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
             rows = list(ws.rows)
-            if not rows:continue
+            if not rows:
+                continue
             ti = list(rows[0])
             for r in list(rows[1:]):
-                l = []
+                fields = []
                 for i, c in enumerate(r):
                     if not c.value:
                         continue
                     t = str(ti[i].value) if i < len(ti) else ""
                     t += (":" if t else "") + str(c.value)
-                    l.append(t)
-                l = "; ".join(l)
+                    fields.append(t)
+                line = "; ".join(fields)
                 if sheetname.lower().find("sheet") < 0:
-                    l += " ——" + sheetname
-                res.append(l)
+                    line += " ——" + sheetname
+                res.append(line)
         return res

     @staticmethod
@@ -36,7 +36,7 @@ class RAGFlowHtmlParser:

     @classmethod
     def parser_txt(cls, txt):
-        if type(txt) != str:
+        if not isinstance(txt, str):
             raise TypeError("txt type should be str!")
         html_doc = readability.Document(txt)
         title = html_doc.title()
@@ -22,7 +22,7 @@ class RAGFlowJsonParser:
         txt = binary.decode(encoding, errors="ignore")
         json_data = json.loads(txt)
         chunks = self.split_json(json_data, True)
-        sections = [json.dumps(l, ensure_ascii=False) for l in chunks if l]
+        sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
         return sections

     @staticmethod
@@ -752,7 +752,7 @@ class RAGFlowPdfParser:
             "x1": np.max([b["x1"] for b in bxs]),
             "bottom": np.max([b["bottom"] for b in bxs]) - ht
         }
-        louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
+        louts = [layout for layout in self.page_layout[pn] if layout["type"] == ltype]
         ii = Recognizer.find_overlapped(b, louts, naive=True)
         if ii is not None:
             b = louts[ii]
@@ -763,7 +763,8 @@ class RAGFlowPdfParser:
                 "layoutno", "")))

         left, top, right, bott = b["x0"], b["top"], b["x1"], b["bottom"]
-        if right < left: right = left + 1
+        if right < left:
+            right = left + 1
         poss.append((pn + self.page_from, left, right, top, bott))
         return self.page_images[pn] \
             .crop((left * ZM, top * ZM,
@@ -845,7 +846,8 @@ class RAGFlowPdfParser:
             top = bx["top"] - self.page_cum_height[pn[0] - 1]
             bott = bx["bottom"] - self.page_cum_height[pn[0] - 1]
             page_images_cnt = len(self.page_images)
-            if pn[-1] - 1 >= page_images_cnt: return ""
+            if pn[-1] - 1 >= page_images_cnt:
+                return ""
             while bott * ZM > self.page_images[pn[-1] - 1].size[1]:
                 bott -= self.page_images[pn[-1] - 1].size[1] / ZM
                 pn.append(pn[-1] + 1)
@@ -889,7 +891,6 @@ class RAGFlowPdfParser:
             nonlocal mh, pw, lines, widths
             lines.append(line)
             widths.append(width(line))
-            width_mean = np.mean(widths)
             mmj = self.proj_match(
                 line["text"]) or line.get(
                 "layout_type",
@@ -994,7 +995,7 @@ class RAGFlowPdfParser:
         else:
             self.is_english = False

-        st = timer()
+        # st = timer()
         for i, img in enumerate(self.page_images_x2):
             chars = self.page_chars[i] if not self.is_english else []
             self.mean_height.append(
@@ -1028,8 +1029,8 @@ class RAGFlowPdfParser:

         self.page_cum_height = np.cumsum(self.page_cum_height)
         assert len(self.page_cum_height) == len(self.page_images) + 1
-        if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
-                                                                page_to, callback)
+        if len(self.boxes) == 0 and zoomin < 9:
+            self.__images__(fnm, zoomin * 3, page_from, page_to, callback)

     def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
         self.__images__(fnm, zoomin)
@@ -1168,7 +1169,7 @@ class PlainParser(object):
         if not self.outlines:
             logging.warning("Miss outlines")

-        return [(l, "") for l in lines], []
+        return [(line, "") for line in lines], []

     def crop(self, ck, need_position):
         raise NotImplementedError
@@ -15,21 +15,42 @@ import datetime


 def refactor(cv):
-    for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]:
-        if n in cv and cv[n] is not None: del cv[n]
+    for n in [
+        "raw_txt",
+        "parser_name",
+        "inference",
+        "ori_text",
+        "use_time",
+        "time_stat",
+    ]:
+        if n in cv and cv[n] is not None:
+            del cv[n]
     cv["is_deleted"] = 0
-    if "basic" not in cv: cv["basic"] = {}
-    if cv["basic"].get("photo2"): del cv["basic"]["photo2"]
+    if "basic" not in cv:
+        cv["basic"] = {}
+    if cv["basic"].get("photo2"):
+        del cv["basic"]["photo2"]

-    for n in ["education", "work", "certificate", "project", "language", "skill", "training"]:
-        if n not in cv or cv[n] is None: continue
-        if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()]
-        if type(cv[n]) != type([]):
+    for n in [
+        "education",
+        "work",
+        "certificate",
+        "project",
+        "language",
+        "skill",
+        "training",
+    ]:
+        if n not in cv or cv[n] is None:
+            continue
+        if isinstance(cv[n], dict):
+            cv[n] = [v for _, v in cv[n].items()]
+        if not isinstance(cv[n], list):
             del cv[n]
             continue
         vv = []
         for v in cv[n]:
-            if "external" in v and v["external"] is not None: del v["external"]
+            if "external" in v and v["external"] is not None:
+                del v["external"]
             vv.append(v)
         cv[n] = {str(i): vv[i] for i in range(len(vv))}

@@ -42,24 +63,44 @@ def refactor(cv):
             cv["basic"][t] = cv["basic"][n]
             del cv["basic"][n]

-    work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", ""))
-    edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", ""))
+    work = sorted(
+        [v for _, v in cv.get("work", {}).items()],
+        key=lambda x: x.get("start_time", ""),
+    )
+    edu = sorted(
+        [v for _, v in cv.get("education", {}).items()],
+        key=lambda x: x.get("start_time", ""),
+    )

     if work:
         cv["basic"]["work_start_time"] = work[0].get("start_time", "")
-        cv["basic"]["management_experience"] = 'Y' if any(
-            [w.get("management_experience", '') == 'Y' for w in work]) else 'N'
+        cv["basic"]["management_experience"] = (
+            "Y"
+            if any([w.get("management_experience", "") == "Y" for w in work])
+            else "N"
+        )
         cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")

-        for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities",
-                  "corporation_type", "scale", "corporation_name"]:
+        for n in [
+            "annual_salary_from",
+            "annual_salary_to",
+            "industry_name",
+            "position_name",
+            "responsibilities",
+            "corporation_type",
+            "scale",
+            "corporation_name",
+        ]:
             cv["basic"][n] = work[-1].get(n, "")

     if edu:
         for n in ["school_name", "discipline_name"]:
-            if n in edu[-1]: cv["basic"][n] = edu[-1][n]
+            if n in edu[-1]:
+                cv["basic"][n] = edu[-1][n]

     cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    if "contact" not in cv: cv["contact"] = {}
-    if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "")
+    if "contact" not in cv:
+        cv["contact"] = {}
+    if not cv["contact"].get("name"):
+        cv["contact"]["name"] = cv["basic"].get("name", "")
     return cv
@@ -21,13 +21,18 @@ from . import regions


 current_file_path = os.path.dirname(os.path.abspath(__file__))
-GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0)
+GOODS = pd.read_csv(
+    os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0
+).fillna(0)
 GOODS["cid"] = GOODS["cid"].astype(str)
 GOODS = GOODS.set_index(["cid"])
-CORP_TKS = json.load(open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r"))
+CORP_TKS = json.load(
+    open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")
+)
 GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r"))
 CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r"))


 def baike(cid, default_v=0):
     global GOODS
     try:
@@ -39,27 +44,41 @@ def baike(cid, default_v=0):

 def corpNorm(nm, add_region=True):
     global CORP_TKS
-    if not nm or type(nm)!=type(""):return ""
+    if not nm or isinstance(nm, str):
+        return ""
     nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower()
     nm = re.sub(r"&amp;", "&", nm)
     nm = re.sub(r"[\(\)()\+'\"\t \*\\【】-]+", " ", nm)
-    nm = re.sub(r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE)
-    nm = re.sub(r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", nm, 10000, re.IGNORECASE)
-    if not nm or (len(nm)<5 and not regions.isName(nm[0:2])):return nm
+    nm = re.sub(
+        r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE
+    )
+    nm = re.sub(
+        r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$",
+        "",
+        nm,
+        10000,
+        re.IGNORECASE,
+    )
+    if not nm or (len(nm) < 5 and not regions.isName(nm[0:2])):
+        return nm

     tks = rag_tokenizer.tokenize(nm).split()
-    reg = [t for i,t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
+    reg = [t for i, t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
     nm = ""
     for t in tks:
-        if regions.isName(t) or t in CORP_TKS:continue
-        if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):nm += " "
+        if regions.isName(t) or t in CORP_TKS:
+            continue
+        if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):
+            nm += " "
         nm += t

     r = re.search(r"^([^a-z0-9 \(\)&]{2,})[a-z ]{4,}$", nm.strip())
-    if r:nm = r.group(1)
+    if r:
+        nm = r.group(1)
     r = re.search(r"^([a-z ]{3,})[^a-z0-9 \(\)&]{2,}$", nm.strip())
-    if r:nm = r.group(1)
-    return nm.strip() + (("" if not reg else "(%s)"%reg[0]) if add_region else "")
+    if r:
+        nm = r.group(1)
+    return nm.strip() + (("" if not reg else "(%s)" % reg[0]) if add_region else "")


 def rmNoise(n):
@@ -67,33 +86,40 @@ def rmNoise(n):
     n = re.sub(r"[,. &()()]+", "", n)
     return n


 GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP])
-for c,v in CORP_TAG.items():
+for c, v in CORP_TAG.items():
     cc = corpNorm(rmNoise(c), False)
     if not cc:
         logging.debug(c)
-CORP_TAG = {corpNorm(rmNoise(c), False):v for c,v in CORP_TAG.items()}
+CORP_TAG = {corpNorm(rmNoise(c), False): v for c, v in CORP_TAG.items()}


 def is_good(nm):
     global GOOD_CORP
-    if nm.find("外派")>=0:return False
+    if nm.find("外派") >= 0:
+        return False
     nm = rmNoise(nm)
     nm = corpNorm(nm, False)
     for n in GOOD_CORP:
         if re.match(r"[0-9a-zA-Z]+$", n):
-            if n == nm: return True
-        elif nm.find(n)>=0:return True
+            if n == nm:
+                return True
+        elif nm.find(n) >= 0:
+            return True
     return False


 def corp_tag(nm):
     global CORP_TAG
     nm = rmNoise(nm)
     nm = corpNorm(nm, False)
     for n in CORP_TAG.keys():
         if re.match(r"[0-9a-zA-Z., ]+$", n):
-            if n == nm: return CORP_TAG[n]
-        elif nm.find(n)>=0:
-            if len(n)<3 and len(nm)/len(n)>=2:continue
+            if n == nm:
+                return CORP_TAG[n]
+        elif nm.find(n) >= 0:
+            if len(n) < 3 and len(nm) / len(n) >= 2:
+                continue
             return CORP_TAG[n]
     return []
@@ -11,27 +11,31 @@
 # limitations under the License.
 #

-TBL = {"94":"EMBA",
-       "6":"MBA",
-       "95":"MPA",
-       "92":"专升本",
-       "4":"专科",
-       "90":"中专",
-       "91":"中技",
-       "86":"初中",
-       "3":"博士",
-       "10":"博士后",
-       "1":"本科",
-       "2":"硕士",
-       "87":"职高",
-       "89":"高中"
-       }
+TBL = {
+    "94": "EMBA",
+    "6": "MBA",
+    "95": "MPA",
+    "92": "专升本",
+    "4": "专科",
+    "90": "中专",
+    "91": "中技",
+    "86": "初中",
+    "3": "博士",
+    "10": "博士后",
+    "1": "本科",
+    "2": "硕士",
+    "87": "职高",
+    "89": "高中",
+}

-TBL_ = {v:k for k,v in TBL.items()}
+TBL_ = {v: k for k, v in TBL.items()}


 def get_name(id):
     return TBL.get(str(id), "")


 def get_id(nm):
-    if not nm:return ""
+    if not nm:
+        return ""
     return TBL_.get(nm.upper().strip(), "")
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -16,8 +16,11 @@ import json
 import re
 import copy
 import pandas as pd

 current_file_path = os.path.dirname(os.path.abspath(__file__))
-TBL = pd.read_csv(os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0).fillna("")
+TBL = pd.read_csv(
+    os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0
+).fillna("")
 TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip())
 GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r"))
 GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH])
@@ -26,14 +29,15 @@ GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH])
 def loadRank(fnm):
     global TBL
     TBL["rank"] = 1000000
-    with open(fnm, "r", encoding='utf-8') as f:
+    with open(fnm, "r", encoding="utf-8") as f:
         while True:
-            l = f.readline()
-            if not l:break
-            l = l.strip("\n").split(",")
+            line = f.readline()
+            if not line:
+                break
+            line = line.strip("\n").split(",")
             try:
-                nm,rk = l[0].strip(),int(l[1])
-                #assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
+                nm, rk = line[0].strip(), int(line[1])
+                # assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
                 TBL.loc[((TBL.name_cn == nm) | (TBL.name_en == nm)), "rank"] = rk
             except Exception:
                 pass
@@ -44,27 +48,35 @@ loadRank(os.path.join(current_file_path, "res/school.rank.csv"))

 def split(txt):
     tks = []
-    for t in re.sub(r"[ \t]+", " ",txt).split():
-        if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
-                re.match(r"[a-zA-Z]", t) and tks:
+    for t in re.sub(r"[ \t]+", " ", txt).split():
+        if (
+            tks
+            and re.match(r".*[a-zA-Z]$", tks[-1])
+            and re.match(r"[a-zA-Z]", t)
+            and tks
+        ):
             tks[-1] = tks[-1] + " " + t
-        else:tks.append(t)
+        else:
+            tks.append(t)
     return tks


 def select(nm):
     global TBL
-    if not nm:return
-    if isinstance(nm, list):nm = str(nm[0])
+    if not nm:
+        return
+    if isinstance(nm, list):
+        nm = str(nm[0])
     nm = split(nm)[0]
     nm = str(nm).lower().strip()
     nm = re.sub(r"[((][^()()]+[))]", "", nm.lower())
     nm = re.sub(r"(^the |[,.&()();;·]+|^(英国|美国|瑞士))", "", nm)
     nm = re.sub(r"大学.*学院", "大学", nm)
     tbl = copy.deepcopy(TBL)
-    tbl["hit_alias"] = tbl["alias"].map(lambda x:nm in set(x.split("+")))
-    res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | (tbl.hit_alias == True))]
-    if res.empty:return
+    tbl["hit_alias"] = tbl["alias"].map(lambda x: nm in set(x.split("+")))
+    res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | tbl.hit_alias)]
+    if res.empty:
+        return

     return json.loads(res.to_json(orient="records"))[0]

@@ -74,4 +86,3 @@ def is_good(nm):
     nm = re.sub(r"[((][^()()]+[))]", "", nm.lower())
     nm = re.sub(r"[''`‘’“”,. &()();;]+", "", nm)
     return nm in GOOD_SCH
-
@@ -25,7 +25,8 @@ from xpinyin import Pinyin
 from contextlib import contextmanager


-class TimeoutException(Exception): pass
+class TimeoutException(Exception):
+    pass


 @contextmanager
@@ -50,8 +51,10 @@ def rmHtmlTag(line):


 def highest_degree(dg):
-    if not dg: return ""
-    if type(dg) == type(""): dg = [dg]
+    if not dg:
+        return ""
+    if isinstance(dg, str):
+        dg = [dg]
     m = {"初中": 0, "高中": 1, "中专": 2, "大专": 3, "专升本": 4, "本科": 5, "硕士": 6, "博士": 7, "博士后": 8}
     return sorted([(d, m.get(d, -1)) for d in dg], key=lambda x: x[1] * -1)[0][0]
@@ -68,10 +71,12 @@ def forEdu(cv):
     for ii, n in enumerate(sorted(cv["education_obj"], key=lambda x: x.get("start_time", "3"))):
         e = {}
         if n.get("end_time"):
-            if n["end_time"] > edu_end_dt: edu_end_dt = n["end_time"]
+            if n["end_time"] > edu_end_dt:
+                edu_end_dt = n["end_time"]
             try:
                 dt = n["end_time"]
-                if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
+                if re.match(r"[0-9]{9,}", dt):
+                    dt = turnTm2Dt(dt)
                 y, m, d = getYMD(dt)
                 ed_dt.append(str(y))
                 e["end_dt_kwd"] = str(y)
@@ -80,7 +85,8 @@ def forEdu(cv):
         if n.get("start_time"):
             try:
                 dt = n["start_time"]
-                if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
+                if re.match(r"[0-9]{9,}", dt):
+                    dt = turnTm2Dt(dt)
                 y, m, d = getYMD(dt)
                 st_dt.append(str(y))
                 e["start_dt_kwd"] = str(y)
@@ -89,13 +95,20 @@ def forEdu(cv):

         r = schools.select(n.get("school_name", ""))
         if r:
-            if str(r.get("type", "")) == "1": fea.append("211")
-            if str(r.get("type", "")) == "2": fea.append("211")
-            if str(r.get("is_abroad", "")) == "1": fea.append("留学")
-            if str(r.get("is_double_first", "")) == "1": fea.append("双一流")
-            if str(r.get("is_985", "")) == "1": fea.append("985")
-            if str(r.get("is_world_known", "")) == "1": fea.append("海外知名")
-            if r.get("rank") and cv["school_rank_int"] > r["rank"]: cv["school_rank_int"] = r["rank"]
+            if str(r.get("type", "")) == "1":
+                fea.append("211")
+            if str(r.get("type", "")) == "2":
+                fea.append("211")
+            if str(r.get("is_abroad", "")) == "1":
+                fea.append("留学")
+            if str(r.get("is_double_first", "")) == "1":
+                fea.append("双一流")
+            if str(r.get("is_985", "")) == "1":
+                fea.append("985")
+            if str(r.get("is_world_known", "")) == "1":
+                fea.append("海外知名")
+            if r.get("rank") and cv["school_rank_int"] > r["rank"]:
+                cv["school_rank_int"] = r["rank"]

         if n.get("school_name") and isinstance(n["school_name"], str):
             sch.append(re.sub(r"(211|985|重点大学|[,&;;-])", "", n["school_name"]))
@@ -106,22 +119,25 @@ def forEdu(cv):
             maj.append(n["discipline_name"])
             e["major_kwd"] = n["discipline_name"]

-        if not n.get("degree") and "985" in fea and not first_fea: n["degree"] = "1"
+        if not n.get("degree") and "985" in fea and not first_fea:
+            n["degree"] = "1"

         if n.get("degree"):
             d = degrees.get_name(n["degree"])
-            if d: e["degree_kwd"] = d
-            if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)",
-                                                                                              n.get(
-                                                                                                  "school_name",
-                                                                                                  ""))): d = "专升本"
-            if d: deg.append(d)
+            if d:
+                e["degree_kwd"] = d
+            if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name",""))):
+                d = "专升本"
+            if d:
+                deg.append(d)

             # for first degree
             if not fdeg and d in ["中专", "专升本", "专科", "本科", "大专"]:
                 fdeg = [d]
-                if n.get("school_name"): fsch = [n["school_name"]]
-                if n.get("discipline_name"): fmaj = [n["discipline_name"]]
+                if n.get("school_name"):
+                    fsch = [n["school_name"]]
+                if n.get("discipline_name"):
+                    fmaj = [n["discipline_name"]]
                 first_fea = copy.deepcopy(fea)

         edu_nst.append(e)
@@ -140,16 +156,26 @@ def forEdu(cv):
     else:
         cv["sch_rank_kwd"].append("一般学校")

-    if edu_nst: cv["edu_nst"] = edu_nst
-    if fea: cv["edu_fea_kwd"] = list(set(fea))
-    if first_fea: cv["edu_first_fea_kwd"] = list(set(first_fea))
-    if maj: cv["major_kwd"] = maj
-    if fsch: cv["first_school_name_kwd"] = fsch
-    if fdeg: cv["first_degree_kwd"] = fdeg
-    if fmaj: cv["first_major_kwd"] = fmaj
-    if st_dt: cv["edu_start_kwd"] = st_dt
-    if ed_dt: cv["edu_end_kwd"] = ed_dt
-    if ed_dt: cv["edu_end_int"] = max([int(t) for t in ed_dt])
+    if edu_nst:
+        cv["edu_nst"] = edu_nst
+    if fea:
+        cv["edu_fea_kwd"] = list(set(fea))
+    if first_fea:
+        cv["edu_first_fea_kwd"] = list(set(first_fea))
+    if maj:
+        cv["major_kwd"] = maj
+    if fsch:
+        cv["first_school_name_kwd"] = fsch
+    if fdeg:
+        cv["first_degree_kwd"] = fdeg
+    if fmaj:
+        cv["first_major_kwd"] = fmaj
+    if st_dt:
+        cv["edu_start_kwd"] = st_dt
+    if ed_dt:
+        cv["edu_end_kwd"] = ed_dt
+    if ed_dt:
+        cv["edu_end_int"] = max([int(t) for t in ed_dt])
     if deg:
         if "本科" in deg and "专科" in deg:
             deg.append("专升本")
@@ -158,8 +184,10 @@ def forEdu(cv):
         cv["highest_degree_kwd"] = highest_degree(deg)
     if edu_end_dt:
         try:
-            if re.match(r"[0-9]{9,}", edu_end_dt): edu_end_dt = turnTm2Dt(edu_end_dt)
-            if edu_end_dt.strip("\n") == "至今": edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
+            if re.match(r"[0-9]{9,}", edu_end_dt):
+                edu_end_dt = turnTm2Dt(edu_end_dt)
+            if edu_end_dt.strip("\n") == "至今":
+                edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
             y, m, d = getYMD(edu_end_dt)
             cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
         except Exception as e:
@@ -171,7 +199,8 @@ def forEdu(cv):
             or not cv.get("degree_kwd"):
         for c in sch:
             if schools.is_good(c):
-                if "tag_kwd" not in cv: cv["tag_kwd"] = []
+                if "tag_kwd" not in cv:
+                    cv["tag_kwd"] = []
                 cv["tag_kwd"].append("好学校")
                 cv["tag_kwd"].append("好学历")
                 break
@@ -180,28 +209,39 @@ def forEdu(cv):
             any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \
             or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \
             or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]):
-        if "tag_kwd" not in cv: cv["tag_kwd"] = []
-        if "好学历" not in cv["tag_kwd"]: cv["tag_kwd"].append("好学历")
+        if "tag_kwd" not in cv:
+            cv["tag_kwd"] = []
+        if "好学历" not in cv["tag_kwd"]:
+            cv["tag_kwd"].append("好学历")

-    if cv.get("major_kwd"): cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
-    if cv.get("school_name_kwd"): cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
-    if cv.get("first_school_name_kwd"): cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
-    if cv.get("first_major_kwd"): cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))
+    if cv.get("major_kwd"):
+        cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
+    if cv.get("school_name_kwd"):
+        cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
+    if cv.get("first_school_name_kwd"):
+        cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
+    if cv.get("first_major_kwd"):
+        cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))

     return cv


 def forProj(cv):
-    if not cv.get("project_obj"): return cv
+    if not cv.get("project_obj"):
+        return cv

     pro_nms, desc = [], []
     for i, n in enumerate(
-            sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if type(x) == type({}) else "",
+            sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "",
                    reverse=True)):
-        if n.get("name"): pro_nms.append(n["name"])
-        if n.get("describe"): desc.append(str(n["describe"]))
-        if n.get("responsibilities"): desc.append(str(n["responsibilities"]))
-        if n.get("achivement"): desc.append(str(n["achivement"]))
+        if n.get("name"):
+            pro_nms.append(n["name"])
+        if n.get("describe"):
+            desc.append(str(n["describe"]))
+        if n.get("responsibilities"):
+            desc.append(str(n["responsibilities"]))
+        if n.get("achivement"):
+            desc.append(str(n["achivement"]))

     if pro_nms:
         # cv["pro_nms_tks"] = rag_tokenizer.tokenize(" ".join(pro_nms))
|
|||||||
work_st_tm = ""
|
work_st_tm = ""
|
||||||
corp_tags = []
|
corp_tags = []
|
||||||
for i, n in enumerate(
|
for i, n in enumerate(
|
||||||
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if type(x) == type({}) else "",
|
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "",
|
||||||
reverse=True)):
|
reverse=True)):
|
||||||
if type(n) == type(""):
|
if isinstance(n, str):
|
||||||
try:
|
try:
|
||||||
n = json_loads(n)
|
n = json_loads(n)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"]
|
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm):
|
||||||
|
work_st_tm = n["start_time"]
|
||||||
for c in flds:
|
for c in flds:
|
||||||
if not n.get(c) or str(n[c]) == '0':
|
if not n.get(c) or str(n[c]) == '0':
|
||||||
fea[c].append("")
|
fea[c].append("")
|
||||||
@ -262,14 +303,18 @@ def forWork(cv):
|
|||||||
fea[c].append(rmHtmlTag(str(n[c]).lower()))
|
fea[c].append(rmHtmlTag(str(n[c]).lower()))
|
||||||
|
|
||||||
y, m, d = getYMD(n.get("start_time"))
|
y, m, d = getYMD(n.get("start_time"))
|
||||||
if not y or not m: continue
|
if not y or not m:
|
||||||
|
continue
|
||||||
st = "%s-%02d-%02d" % (y, int(m), int(d))
|
st = "%s-%02d-%02d" % (y, int(m), int(d))
|
||||||
latest_job_tm = st
|
latest_job_tm = st
|
||||||
|
|
||||||
y, m, d = getYMD(n.get("end_time"))
|
y, m, d = getYMD(n.get("end_time"))
|
||||||
if (not y or not m) and i > 0: continue
|
if (not y or not m) and i > 0:
|
||||||
if not y or not m or int(y) > 2022: y, m, d = getYMD(str(n.get("updated_at", "")))
|
continue
|
||||||
if not y or not m: continue
|
if not y or not m or int(y) > 2022:
|
||||||
|
y, m, d = getYMD(str(n.get("updated_at", "")))
|
||||||
|
if not y or not m:
|
||||||
|
continue
|
||||||
ed = "%s-%02d-%02d" % (y, int(m), int(d))
|
ed = "%s-%02d-%02d" % (y, int(m), int(d))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -279,22 +324,28 @@ def forWork(cv):
|
|||||||
|
|
||||||
if n.get("scale"):
|
if n.get("scale"):
|
||||||
r = re.search(r"^([0-9]+)", str(n["scale"]))
|
r = re.search(r"^([0-9]+)", str(n["scale"]))
|
||||||
if r: scales.append(int(r.group(1)))
|
if r:
|
||||||
|
scales.append(int(r.group(1)))
|
||||||
|
|
||||||
if goodcorp:
|
if goodcorp:
|
||||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
if "tag_kwd" not in cv:
|
||||||
|
cv["tag_kwd"] = []
|
||||||
cv["tag_kwd"].append("好公司")
|
cv["tag_kwd"].append("好公司")
|
||||||
if goodcorp_:
|
if goodcorp_:
|
||||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
if "tag_kwd" not in cv:
|
||||||
|
cv["tag_kwd"] = []
|
||||||
cv["tag_kwd"].append("好公司(曾)")
|
cv["tag_kwd"].append("好公司(曾)")
|
||||||
|
|
||||||
if corp_tags:
|
if corp_tags:
|
||||||
if "tag_kwd" not in cv: cv["tag_kwd"] = []
|
if "tag_kwd" not in cv:
|
||||||
|
cv["tag_kwd"] = []
|
||||||
cv["tag_kwd"].extend(corp_tags)
|
cv["tag_kwd"].extend(corp_tags)
|
||||||
cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)]
|
cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)]
|
||||||
|
|
||||||
if latest_job_tm: cv["latest_job_dt"] = latest_job_tm
|
if latest_job_tm:
|
||||||
if fea["corporation_id"]: cv["corporation_id"] = fea["corporation_id"]
|
cv["latest_job_dt"] = latest_job_tm
|
||||||
|
if fea["corporation_id"]:
|
||||||
|
cv["corporation_id"] = fea["corporation_id"]
|
||||||
|
|
||||||
if fea["position_name"]:
|
if fea["position_name"]:
|
||||||
cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0])
|
cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0])
|
||||||
@ -317,18 +368,23 @@ def forWork(cv):
|
|||||||
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0])
|
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0])
|
||||||
cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:]))
|
cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:]))
|
||||||
|
|
||||||
if fea["subordinates_count"]: fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
|
if fea["subordinates_count"]:
|
||||||
|
fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
|
||||||
re.match(r"[^0-9]+$", str(i))]
|
re.match(r"[^0-9]+$", str(i))]
|
||||||
if fea["subordinates_count"]: cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
|
if fea["subordinates_count"]:
|
||||||
|
cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
|
||||||
|
|
||||||
if type(cv.get("corporation_id")) == type(1): cv["corporation_id"] = [str(cv["corporation_id"])]
|
if isinstance(cv.get("corporation_id"), int):
|
||||||
if not cv.get("corporation_id"): cv["corporation_id"] = []
|
cv["corporation_id"] = [str(cv["corporation_id"])]
|
||||||
|
if not cv.get("corporation_id"):
|
||||||
|
cv["corporation_id"] = []
|
||||||
for i in cv.get("corporation_id", []):
|
for i in cv.get("corporation_id", []):
|
||||||
cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0)
|
cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0)
|
||||||
|
|
||||||
if work_st_tm:
|
if work_st_tm:
|
||||||
try:
|
try:
|
||||||
if re.match(r"[0-9]{9,}", work_st_tm): work_st_tm = turnTm2Dt(work_st_tm)
|
if re.match(r"[0-9]{9,}", work_st_tm):
|
||||||
|
work_st_tm = turnTm2Dt(work_st_tm)
|
||||||
y, m, d = getYMD(work_st_tm)
|
y, m, d = getYMD(work_st_tm)
|
||||||
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
|
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -339,28 +395,37 @@ def forWork(cv):
|
|||||||
cv["dua_flt"] = np.mean(duas)
|
cv["dua_flt"] = np.mean(duas)
|
||||||
cv["cur_dua_int"] = duas[0]
|
cv["cur_dua_int"] = duas[0]
|
||||||
cv["job_num_int"] = len(duas)
|
cv["job_num_int"] = len(duas)
|
||||||
if scales: cv["scale_flt"] = np.max(scales)
|
if scales:
|
||||||
|
cv["scale_flt"] = np.max(scales)
|
||||||
return cv
|
return cv
|
||||||
|
|
||||||
|
|
||||||
def turnTm2Dt(b):
|
def turnTm2Dt(b):
|
||||||
if not b: return
|
if not b:
|
||||||
|
return
|
||||||
b = str(b).strip()
|
b = str(b).strip()
|
||||||
if re.match(r"[0-9]{10,}", b): b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
|
if re.match(r"[0-9]{10,}", b):
|
||||||
|
b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
|
||||||
return b
|
return b
|
||||||
|
|
||||||
|
|
||||||
def getYMD(b):
|
def getYMD(b):
|
||||||
y, m, d = "", "", "01"
|
y, m, d = "", "", "01"
|
||||||
if not b: return (y, m, d)
|
if not b:
|
||||||
|
return (y, m, d)
|
||||||
b = turnTm2Dt(b)
|
b = turnTm2Dt(b)
|
||||||
if re.match(r"[0-9]{4}", b): y = int(b[:4])
|
if re.match(r"[0-9]{4}", b):
|
||||||
|
y = int(b[:4])
|
||||||
r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b)
|
r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b)
|
||||||
if r: m = r.group(1)
|
if r:
|
||||||
|
m = r.group(1)
|
||||||
r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b)
|
r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b)
|
||||||
if r: d = r.group(1)
|
if r:
|
||||||
if not d or int(d) == 0 or int(d) > 31: d = "1"
|
d = r.group(1)
|
||||||
if not m or int(m) > 12 or int(m) < 1: m = "1"
|
if not d or int(d) == 0 or int(d) > 31:
|
||||||
|
d = "1"
|
||||||
|
if not m or int(m) > 12 or int(m) < 1:
|
||||||
|
m = "1"
|
||||||
return (y, m, d)
|
return (y, m, d)
|
||||||
|
|
||||||
|
|
||||||
@ -369,7 +434,8 @@ def birth(cv):
|
|||||||
cv["integerity_flt"] *= 0.9
|
cv["integerity_flt"] *= 0.9
|
||||||
return cv
|
return cv
|
||||||
y, m, d = getYMD(cv["birth"])
|
y, m, d = getYMD(cv["birth"])
|
||||||
if not m or not y: return cv
|
if not m or not y:
|
||||||
|
return cv
|
||||||
b = "%s-%02d-%02d" % (y, int(m), int(d))
|
b = "%s-%02d-%02d" % (y, int(m), int(d))
|
||||||
cv["birth_dt"] = b
|
cv["birth_dt"] = b
|
||||||
cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d))
|
cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d))
|
||||||
@ -380,7 +446,8 @@ def birth(cv):
|
|||||||
|
|
||||||
def parse(cv):
|
def parse(cv):
|
||||||
for k in cv.keys():
|
for k in cv.keys():
|
||||||
if cv[k] == '\\N': cv[k] = ''
|
if cv[k] == '\\N':
|
||||||
|
cv[k] = ''
|
||||||
# cv = cv.asDict()
|
# cv = cv.asDict()
|
||||||
tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names",
|
tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names",
|
||||||
"expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name",
|
"expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name",
|
||||||
@ -402,9 +469,12 @@ def parse(cv):
|
|||||||
|
|
||||||
rmkeys = []
|
rmkeys = []
|
||||||
for k in cv.keys():
|
for k in cv.keys():
|
||||||
if cv[k] is None: rmkeys.append(k)
|
if cv[k] is None:
|
||||||
if (type(cv[k]) == type([]) or type(cv[k]) == type("")) and len(cv[k]) == 0: rmkeys.append(k)
|
rmkeys.append(k)
|
||||||
for k in rmkeys: del cv[k]
|
if (isinstance(cv[k], list) or isinstance(cv[k], str)) and len(cv[k]) == 0:
|
||||||
|
rmkeys.append(k)
|
||||||
|
for k in rmkeys:
|
||||||
|
del cv[k]
|
||||||
|
|
||||||
integerity = 0.
|
integerity = 0.
|
||||||
flds_num = 0.
|
flds_num = 0.
|
||||||
@ -414,7 +484,8 @@ def parse(cv):
|
|||||||
flds_num += len(flds)
|
flds_num += len(flds)
|
||||||
for f in flds:
|
for f in flds:
|
||||||
v = str(cv.get(f, ""))
|
v = str(cv.get(f, ""))
|
||||||
if len(v) > 0 and v != '0' and v != '[]': integerity += 1
|
if len(v) > 0 and v != '0' and v != '[]':
|
||||||
|
integerity += 1
|
||||||
|
|
||||||
hasValues(tks_fld)
|
hasValues(tks_fld)
|
||||||
hasValues(small_tks_fld)
|
hasValues(small_tks_fld)
|
||||||
@ -433,7 +504,8 @@ def parse(cv):
|
|||||||
(r"[ ()\(\)人/·0-9-]+", ""),
|
(r"[ ()\(\)人/·0-9-]+", ""),
|
||||||
(r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]:
|
(r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]:
|
||||||
cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE)
|
cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE)
|
||||||
if len(cv["corporation_type"]) < 2: del cv["corporation_type"]
|
if len(cv["corporation_type"]) < 2:
|
||||||
|
del cv["corporation_type"]
|
||||||
|
|
||||||
if cv.get("political_status"):
|
if cv.get("political_status"):
|
||||||
for p, r in [
|
for p, r in [
|
||||||
@ -441,9 +513,11 @@ def parse(cv):
|
|||||||
(r".*(无党派|公民).*", "群众"),
|
(r".*(无党派|公民).*", "群众"),
|
||||||
(r".*团员.*", "团员")]:
|
(r".*团员.*", "团员")]:
|
||||||
cv["political_status"] = re.sub(p, r, cv["political_status"])
|
cv["political_status"] = re.sub(p, r, cv["political_status"])
|
||||||
if not re.search(r"[党团群]", cv["political_status"]): del cv["political_status"]
|
if not re.search(r"[党团群]", cv["political_status"]):
|
||||||
|
del cv["political_status"]
|
||||||
|
|
||||||
if cv.get("phone"): cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
|
if cv.get("phone"):
|
||||||
|
cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
|
||||||
|
|
||||||
keys = list(cv.keys())
|
keys = list(cv.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
@ -454,9 +528,11 @@ def parse(cv):
|
|||||||
cv[k] = [a for _, a in cv[k].items()]
|
cv[k] = [a for _, a in cv[k].items()]
|
||||||
nms = []
|
nms = []
|
||||||
for n in cv[k]:
|
for n in cv[k]:
|
||||||
if type(n) != type({}) or "name" not in n or not n.get("name"): continue
|
if not isinstance(n, dict) or "name" not in n or not n.get("name"):
|
||||||
|
continue
|
||||||
n["name"] = re.sub(r"((442)|\t )", "", n["name"]).strip().lower()
|
n["name"] = re.sub(r"((442)|\t )", "", n["name"]).strip().lower()
|
||||||
if not n["name"]: continue
|
if not n["name"]:
|
||||||
|
continue
|
||||||
nms.append(n["name"])
|
nms.append(n["name"])
|
||||||
if nms:
|
if nms:
|
||||||
t = k[:-4]
|
t = k[:-4]
|
||||||
@ -469,15 +545,18 @@ def parse(cv):
|
|||||||
# tokenize fields
|
# tokenize fields
|
||||||
if k in tks_fld:
|
if k in tks_fld:
|
||||||
cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k])
|
cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k])
|
||||||
if k in small_tks_fld: cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
|
if k in small_tks_fld:
|
||||||
|
cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
|
||||||
|
|
||||||
# keyword fields
|
# keyword fields
|
||||||
if k in kwd_fld: cv[f"{k}_kwd"] = [n.lower()
|
if k in kwd_fld:
|
||||||
|
cv[f"{k}_kwd"] = [n.lower()
|
||||||
for n in re.split(r"[\t,,;;. ]",
|
for n in re.split(r"[\t,,;;. ]",
|
||||||
re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1,\2", cv[k])
|
re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1,\2", cv[k])
|
||||||
) if n]
|
) if n]
|
||||||
|
|
||||||
if k in num_fld and cv.get(k): cv[f"{k}_int"] = cv[k]
|
if k in num_fld and cv.get(k):
|
||||||
|
cv[f"{k}_int"] = cv[k]
|
||||||
|
|
||||||
cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "")
|
cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "")
|
||||||
# for name field
|
# for name field
|
||||||
@ -501,10 +580,12 @@ def parse(cv):
|
|||||||
cv["name_py_pref0_tks"] = ""
|
cv["name_py_pref0_tks"] = ""
|
||||||
cv["name_py_pref_tks"] = ""
|
cv["name_py_pref_tks"] = ""
|
||||||
for py in PY.get_pinyins(nm[:20], ''):
|
for py in PY.get_pinyins(nm[:20], ''):
|
||||||
for i in range(2, len(py) + 1): cv["name_py_pref_tks"] += " " + py[:i]
|
for i in range(2, len(py) + 1):
|
||||||
|
cv["name_py_pref_tks"] += " " + py[:i]
|
||||||
for py in PY.get_pinyins(nm[:20], ' '):
|
for py in PY.get_pinyins(nm[:20], ' '):
|
||||||
py = py.split()
|
py = py.split()
|
||||||
for i in range(1, len(py) + 1): cv["name_py_pref0_tks"] += " " + "".join(py[:i])
|
for i in range(1, len(py) + 1):
|
||||||
|
cv["name_py_pref0_tks"] += " " + "".join(py[:i])
|
||||||
|
|
||||||
cv["name_kwd"] = name
|
cv["name_kwd"] = name
|
||||||
cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3]
|
cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3]
|
||||||
@ -526,22 +607,30 @@ def parse(cv):
|
|||||||
cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S')
|
cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S')
|
||||||
else:
|
else:
|
||||||
y, m, d = getYMD(str(cv.get("updated_at", "")))
|
y, m, d = getYMD(str(cv.get("updated_at", "")))
|
||||||
if not y: y = "2012"
|
if not y:
|
||||||
if not m: m = "01"
|
y = "2012"
|
||||||
if not d: d = "01"
|
if not m:
|
||||||
|
m = "01"
|
||||||
|
if not d:
|
||||||
|
d = "01"
|
||||||
cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
|
cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
|
||||||
# long text tokenize
|
# long text tokenize
|
||||||
|
|
||||||
if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
|
if cv.get("responsibilities"):
|
||||||
|
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
|
||||||
|
|
||||||
# for yes or no field
|
# for yes or no field
|
||||||
fea = []
|
fea = []
|
||||||
for f, y, n in is_fld:
|
for f, y, n in is_fld:
|
||||||
if f not in cv: continue
|
if f not in cv:
|
||||||
if cv[f] == '是': fea.append(y)
|
continue
|
||||||
if cv[f] == '否': fea.append(n)
|
if cv[f] == '是':
|
||||||
|
fea.append(y)
|
||||||
|
if cv[f] == '否':
|
||||||
|
fea.append(n)
|
||||||
|
|
||||||
if fea: cv["tag_kwd"] = fea
|
if fea:
|
||||||
|
cv["tag_kwd"] = fea
|
||||||
|
|
||||||
cv = forEdu(cv)
|
cv = forEdu(cv)
|
||||||
cv = forProj(cv)
|
cv = forProj(cv)
|
||||||
@ -550,9 +639,11 @@ def parse(cv):
|
|||||||
|
|
||||||
cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])]
|
cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])]
|
||||||
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
|
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
|
||||||
for j in cv.get("sch_rank_kwd", []): cv["corp_proj_sch_deg_kwd"][i] += "+" + j
|
for j in cv.get("sch_rank_kwd", []):
|
||||||
|
cv["corp_proj_sch_deg_kwd"][i] += "+" + j
|
||||||
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
|
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
|
||||||
if cv.get("highest_degree_kwd"): cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
|
if cv.get("highest_degree_kwd"):
|
||||||
|
cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not cv.get("work_exp_flt") and cv.get("work_start_time"):
|
if not cv.get("work_exp_flt") and cv.get("work_start_time"):
|
||||||
@ -565,17 +656,21 @@ def parse(cv):
|
|||||||
cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y)
|
cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time")))
|
logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time")))
|
||||||
if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
|
if "work_exp_flt" not in cv and cv.get("work_experience", 0):
|
||||||
|
cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
|
||||||
|
|
||||||
keys = list(cv.keys())
|
keys = list(cv.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k): del cv[k]
|
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k):
|
||||||
|
del cv[k]
|
||||||
for k in cv.keys():
|
for k in cv.keys():
|
||||||
if not re.search("_(kwd|id)$", k) or type(cv[k]) != type([]): continue
|
if not re.search("_(kwd|id)$", k) or not isinstance(cv[k], list):
|
||||||
|
continue
|
||||||
cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']]))
|
cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']]))
|
||||||
keys = [k for k in cv.keys() if re.search(r"_feas*$", k)]
|
keys = [k for k in cv.keys() if re.search(r"_feas*$", k)]
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if cv[k] <= 0: del cv[k]
|
if cv[k] <= 0:
|
||||||
|
del cv[k]
|
||||||
|
|
||||||
cv["tob_resume_id"] = str(cv["tob_resume_id"])
|
cv["tob_resume_id"] = str(cv["tob_resume_id"])
|
||||||
cv["id"] = cv["tob_resume_id"]
|
cv["id"] = cv["tob_resume_id"]
|
||||||
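The isinstance change above replaces a type() comparison that Ruff flags (E721). A short sketch, using an invented helper name, of the preferred form:

from typing import Any

def is_list_field(value: Any) -> bool:
    # Old style, flagged by Ruff E721: type(value) != type([])
    # Preferred style, as applied in parse(cv) above:
    return isinstance(value, list)

assert is_list_field(["北京", "上海"])
assert not is_list_field("北京")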
@ -592,5 +687,6 @@ def dealWithInt64(d):
     if isinstance(d, list):
         d = [dealWithInt64(t) for t in d]

-    if isinstance(d, np.integer): d = int(d)
+    if isinstance(d, np.integer):
+        d = int(d)
     return d

@ -51,6 +51,7 @@ class RAGFlowTxtParser:
         dels = [d for d in dels if d]
         dels = "|".join(dels)
         secs = re.split(r"(%s)" % dels, txt)
-        for sec in secs: add_chunk(sec)
+        for sec in secs:
+            add_chunk(sec)

         return [[c, ""] for c in cks]
@ -18,7 +18,6 @@ from .recognizer import Recognizer
 from .layout_recognizer import LayoutRecognizer
 from .table_structure_recognizer import TableStructureRecognizer


 def init_in_out(args):
     from PIL import Image
     import os

@ -47,7 +46,7 @@ def init_in_out(args):
         try:
             images.append(Image.open(fnm))
             outputs.append(os.path.split(fnm)[-1])
-        except Exception as e:
+        except Exception:
             traceback.print_exc()

     if os.path.isdir(args.inputs):

@ -56,6 +55,16 @@ def init_in_out(args):
     else:
         images_and_outputs(args.inputs)

-    for i in range(len(outputs)): outputs[i] = os.path.join(args.output_dir, outputs[i])
+    for i in range(len(outputs)):
+        outputs[i] = os.path.join(args.output_dir, outputs[i])

     return images, outputs


+__all__ = [
+    "OCR",
+    "Recognizer",
+    "LayoutRecognizer",
+    "TableStructureRecognizer",
+    "init_in_out",
+]
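Two patterns from the hunks above, shown in a standalone sketch (file and function names are illustrative): dropping an unused "as e" binding from an except clause, and declaring an explicit __all__ so re-exported names are not reported as unused imports.

import traceback

def try_open(path):
    try:
        return open(path, "rb")
    except Exception:  # no "as e": the exception object was never used
        traceback.print_exc()
        return None

# Explicit export list: linters now treat these names as intentional re-exports.
__all__ = [
    "try_open",
]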
@ -42,7 +42,7 @@ class LayoutRecognizer(Recognizer):
                 get_project_base_directory(),
                 "rag/res/deepdoc")
             super().__init__(self.labels, domain, model_dir)
-        except Exception as e:
+        except Exception:
             model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                           local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                           local_dir_use_symlinks=False)

@ -77,7 +77,7 @@ class LayoutRecognizer(Recognizer):
                 "page_number": pn,
             } for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
             lts = self.sort_Y_firstly(lts, np.mean(
-                [l["bottom"] - l["top"] for l in lts]) / 2)
+                [lt["bottom"] - lt["top"] for lt in lts]) / 2)
             lts = self.layouts_cleanup(bxs, lts)
             page_layout.append(lts)
@ -19,7 +19,9 @@ from huggingface_hub import snapshot_download

 from api.utils.file_utils import get_project_base_directory
 from .operators import *
+import math
 import numpy as np
+import cv2
 import onnxruntime as ort

 from .postprocess import build_post_process

@ -484,7 +486,7 @@ class OCR(object):
                 "rag/res/deepdoc")
             self.text_detector = TextDetector(model_dir)
             self.text_recognizer = TextRecognizer(model_dir)
-        except Exception as e:
+        except Exception:
             model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                           local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                           local_dir_use_symlinks=False)

@ -232,7 +232,7 @@ class LinearResize(object):
         """
         assert len(self.target_size) == 2
         assert self.target_size[0] > 0 and self.target_size[1] > 0
-        im_channel = im.shape[2]
+        _im_channel = im.shape[2]
         im_scale_y, im_scale_x = self.generate_scale(im)
         im = cv2.resize(
             im,

@ -255,7 +255,7 @@ class LinearResize(object):
             im_scale_y: the resize ratio of Y
         """
         origin_shape = im.shape[:2]
-        im_c = im.shape[2]
+        _im_c = im.shape[2]
         if self.keep_ratio:
             im_size_min = np.min(origin_shape)
             im_size_max = np.max(origin_shape)

@ -581,7 +581,7 @@ class SRResize(object):
             return data

         images_HR = data["image_hr"]
-        label_strs = data["label"]
+        _label_strs = data["label"]
         transform = ResizeNormalize((imgW, imgH))
         images_HR = transform(images_HR)
         data["img_hr"] = images_HR

@ -121,7 +121,7 @@ class DBPostProcess(object):
         outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                 cv2.CHAIN_APPROX_SIMPLE)
         if len(outs) == 3:
-            img, contours, _ = outs[0], outs[1], outs[2]
+            _img, contours, _ = outs[0], outs[1], outs[2]
         elif len(outs) == 2:
             contours, _ = outs[0], outs[1]
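The renames above (im_channel to _im_channel, img to _img, label_strs to _label_strs) keep values that must be unpacked or read but are never used afterwards; the leading underscore is the conventional way to mark that and keep the unused-variable check quiet. A tiny sketch with made-up data:

# cv2.findContours-style tuple, values invented for illustration.
outs = ("image", ["contour_a", "contour_b"], "hierarchy")

# Underscore-prefixed names mark bindings we deliberately ignore.
_img, contours, _hierarchy = outs
print(len(contours))  # 2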
@ -13,15 +13,18 @@

 import logging
 import os
+import math
+import numpy as np
+import cv2
 from copy import deepcopy


 import onnxruntime as ort
 from huggingface_hub import snapshot_download

 from api.utils.file_utils import get_project_base_directory
 from .operators import *


 class Recognizer(object):
     def __init__(self, label_list, task_name, model_dir=None):
         """

@ -277,7 +280,8 @@ class Recognizer(object):
             return
         min_dis, min_i = 1000000, None
         for i,b in enumerate(boxes):
-            if box.get("layoutno", "0") != b.get("layoutno", "0"): continue
+            if box.get("layoutno", "0") != b.get("layoutno", "0"):
+                continue
             dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
             if dis < min_dis:
                 min_i = i

@ -402,7 +406,8 @@ class Recognizer(object):
         scores = np.max(boxes[:, 4:], axis=1)
         boxes = boxes[scores > thr, :]
         scores = scores[scores > thr]
-        if len(boxes) == 0: return []
+        if len(boxes) == 0:
+            return []

         # Get the class with the highest confidence
         class_ids = np.argmax(boxes[:, 4:], axis=1)

@ -432,7 +437,8 @@ class Recognizer(object):
         for i in range(len(image_list)):
             if not isinstance(image_list[i], np.ndarray):
                 imgs.append(np.array(image_list[i]))
-            else: imgs.append(image_list[i])
+            else:
+                imgs.append(image_list[i])

         batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
         for i in range(batch_loop_cnt):
@ -88,7 +88,8 @@ class CommunityReportsExtractor:
                     ("findings", list),
                     ("rating", float),
                     ("rating_explanation", str),
-                ]): continue
+                ]):
+                    continue
                 response["weight"] = weight
                 response["entities"] = ents
             except Exception as e:

@ -100,7 +101,8 @@ class CommunityReportsExtractor:
             res_str.append(self._get_text_output(response))
             res_dict.append(response)
             over += 1
-            if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
+            if callback:
+                callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

         return CommunityReportsResult(
             structured_output=res_dict,

@ -8,6 +8,7 @@ Reference:
 from typing import Any
 import numpy as np
 import networkx as nx
+from dataclasses import dataclass
 from graphrag.leiden import stable_largest_connected_component


@ -129,9 +129,11 @@ class GraphExtractor:
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
                 total_token_count += token_count
-                if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
+                if callback:
+                    callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
             except Exception as e:
-                if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e)))
+                if callback:
+                    callback(msg="Knowledge graph extraction error:{}".format(str(e)))
                 logging.exception("error extracting graph")
                 self._on_error(
                     e,

@ -164,7 +166,8 @@ class GraphExtractor:
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
         gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-        if response.find("**ERROR**") >= 0: raise Exception(response)
+        if response.find("**ERROR**") >= 0:
+            raise Exception(response)
         token_count = num_tokens_from_string(text + response)

         results = response or ""

@ -175,7 +178,8 @@ class GraphExtractor:
             text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
             history.append({"role": "user", "content": text})
             response = self._llm.chat("", history, gen_conf)
-            if response.find("**ERROR**") >=0: raise Exception(response)
+            if response.find("**ERROR**") >=0:
+                raise Exception(response)
             results += response or ""

             # if this is the final glean, don't bother updating the continuation flag

@ -134,7 +134,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
     callback(0.75, "Extracting mind graph.")
     mindmap = MindMapExtractor(llm_bdl)
     mg = mindmap(_chunks).output
-    if not len(mg.keys()): return chunks
+    if not len(mg.keys()):
+        return chunks

     logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
     chunks.append(

@ -78,7 +78,8 @@ def _compute_leiden_communities(
 ) -> dict[int, dict[str, int]]:
     """Return Leiden root communities."""
     results: dict[int, dict[str, int]] = {}
-    if is_empty(graph): return results
+    if is_empty(graph):
+        return results
     if use_lcc:
         graph = stable_largest_connected_component(graph)

@ -100,7 +101,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
     logging.debug(
         "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
     )
-    if not graph.nodes(): return {}
+    if not graph.nodes():
+        return {}

     node_id_to_community_map = _compute_leiden_communities(
         graph=graph,

@ -125,9 +127,11 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
             result[community_id]["nodes"].append(node_id)
             result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
         weights = [comm["weight"] for _, comm in result.items()]
-        if not weights:continue
+        if not weights:
+            continue
         max_weight = max(weights)
-        for _, comm in result.items(): comm["weight"] /= max_weight
+        for _, comm in result.items():
+            comm["weight"] /= max_weight

     return results_by_level
@ -1 +1,5 @@
-from .ragflow_chat import *
+from .ragflow_chat import RAGFlowChat
+
+__all__ = [
+    "RAGFlowChat"
+]

@ -2,7 +2,6 @@ import logging
 import requests
 from bridge.context import ContextType  # Import Context, ContextType
 from bridge.reply import Reply, ReplyType  # Import Reply, ReplyType
-from bridge import *
 from plugins import Plugin, register  # Import Plugin and register
 from plugins.event import Event, EventContext, EventAction  # Import event-related classes
@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         callback(0.1, "Start to parse.")
         txt = get_text(filename, binary)
         sections = txt.split("\n")
-        sections = [(l, "") for l in sections if l]
+        sections = [(line, "") for line in sections if line]
         remove_contents_table(sections, eng=is_english(
             random_choices([t for t, _ in sections], k=200)))
         callback(0.8, "Finish parsing.")

@ -102,7 +102,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         sections = HtmlParser()(filename, binary)
-        sections = [(l, "") for l in sections if l]
+        sections = [(line, "") for line in sections if line]
         remove_contents_table(sections, eng=is_english(
             random_choices([t for t, _ in sections], k=200)))
         callback(0.8, "Finish parsing.")

@ -112,7 +112,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         binary = BytesIO(binary)
         doc_parsed = parser.from_buffer(binary)
         sections = doc_parsed['content'].split('\n')
-        sections = [(l, "") for l in sections if l]
+        sections = [(line, "") for line in sections if line]
         remove_contents_table(sections, eng=is_english(
             random_choices([t for t, _ in sections], k=200)))
         callback(0.8, "Finish parsing.")

@ -75,7 +75,7 @@ def chunk(
         _add_content(msg, msg.get_content_type())

     sections = TxtParser.parser_txt("\n".join(text_txt)) + [
-        (l, "") for l in HtmlParser.parser_txt("\n".join(html_txt)) if l
+        (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line
     ]

     st = timer()
@ -18,7 +18,8 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
     chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
                                           parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
                                           )
-    for c in chunks: c["docnm_kwd"] = filename
+    for c in chunks:
+        c["docnm_kwd"] = filename

     doc = {
         "docnm_kwd": filename,

@ -48,7 +48,7 @@ class Docx(DocxParser):
                     continue
                 if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
                     pn += 1
-        return [l for l in lines if l]
+        return [line for line in lines if line]

     def __call__(self, filename, binary=None, from_page=0, to_page=100000):
         self.doc = Document(

@ -60,7 +60,8 @@ class Docx(DocxParser):
             if pn > to_page:
                 break
             question_level, p_text = docx_question_level(p, bull)
-            if not p_text.strip("\n"):continue
+            if not p_text.strip("\n"):
+                continue
             lines.append((question_level, p_text))

             for run in p.runs:

@ -78,19 +79,21 @@ class Docx(DocxParser):
                 if lines[e][0] <= lines[s][0]:
                     break
                 e += 1
-            if e - s == 1 and visit[s]: continue
+            if e - s == 1 and visit[s]:
+                continue
             sec = []
             next_level = lines[s][0] + 1
             while not sec and next_level < 22:
                 for i in range(s+1, e):
-                    if lines[i][0] != next_level: continue
+                    if lines[i][0] != next_level:
+                        continue
                     sec.append(lines[i][1])
                     visit[i] = True
                 next_level += 1
             sec.insert(0, lines[s][1])

             sections.append("\n".join(sec))
-        return [l for l in sections if l]
+        return [s for s in sections if s]

     def __str__(self) -> str:
         return f'''
@ -168,13 +171,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         callback(0.1, "Start to parse.")
         txt = get_text(filename, binary)
         sections = txt.split("\n")
-        sections = [l for l in sections if l]
+        sections = [s for s in sections if s]
         callback(0.8, "Finish parsing.")

     elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         sections = HtmlParser()(filename, binary)
-        sections = [l for l in sections if l]
+        sections = [s for s in sections if s]
         callback(0.8, "Finish parsing.")

     elif re.search(r"\.doc$", filename, re.IGNORECASE):

@ -182,7 +185,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         binary = BytesIO(binary)
         doc_parsed = parser.from_buffer(binary)
         sections = doc_parsed['content'].split('\n')
-        sections = [l for l in sections if l]
+        sections = [s for s in sections if s]
         callback(0.8, "Finish parsing.")

     else:

@ -190,7 +190,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         sections, tbls = pdf_parser(filename if not binary else binary,
                                     from_page=from_page, to_page=to_page, callback=callback)
         if sections and len(sections[0]) < 3:
-            sections = [(t, l, [[0] * 5]) for t, l in sections]
+            sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
         # set pivot using the most frequent type of title,
         # then merge between 2 pivot
         if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1:

@ -211,7 +211,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         else:
             bull = bullets_category([txt for txt, _, _ in sections])
             most_level, levels = title_frequency(
-                bull, [(txt, l) for txt, l, poss in sections])
+                bull, [(txt, lvl) for txt, lvl, _ in sections])

         assert len(sections) == len(levels)
         sec_ids = []

@ -225,7 +225,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         sections = [(txt, sec_ids[i], poss)
                     for i, (txt, _, poss) in enumerate(sections)]
         for (img, rows), poss in tbls:
-            if not rows: continue
+            if not rows:
+                continue
             sections.append((rows if isinstance(rows, str) else rows[0], -1,
                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))

@ -54,7 +54,8 @@ class Pdf(PdfParser):
         sections = [(b["text"], self.get_position(b, zoomin))
                     for i, b in enumerate(self.boxes)]
         for (img, rows), poss in tbls:
-            if not rows:continue
+            if not rows:
+                continue
             sections.append((rows if isinstance(rows, str) else rows[0],
                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
         return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (

@ -109,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         binary = BytesIO(binary)
         doc_parsed = parser.from_buffer(binary)
         sections = doc_parsed['content'].split('\n')
-        sections = [l for l in sections if l]
+        sections = [s for s in sections if s]
         callback(0.8, "Finish parsing.")

     else:

@ -171,7 +171,7 @@ class Pdf(PdfParser):
             tbl_bottom = tbls[tbl_index][1][0][4]
             tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
                 .format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
-            tbl_text = ''.join(tbls[tbl_index][0][1])
+            _tbl_text = ''.join(tbls[tbl_index][0][1])
             return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag,
@ -325,9 +325,11 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         txt = get_text(filename, binary)
         lines = txt.split("\n")
         comma, tab = 0, 0
-        for l in lines:
-            if len(l.split(",")) == 2: comma += 1
-            if len(l.split("\t")) == 2: tab += 1
+        for line in lines:
+            if len(line.split(",")) == 2:
+                comma += 1
+            if len(line.split("\t")) == 2:
+                tab += 1
         delimiter = "\t" if tab >= comma else ","

         fails = []

@ -336,18 +338,21 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         while i < len(lines):
             arr = lines[i].split(delimiter)
             if len(arr) != 2:
-                if question: answer += "\n" + lines[i]
+                if question:
+                    answer += "\n" + lines[i]
                 else:
                     fails.append(str(i+1))
             elif len(arr) == 2:
-                if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng))
+                if question and answer:
+                    res.append(beAdoc(deepcopy(doc), question, answer, eng))
                 question, answer = arr
             i += 1
             if len(res) % 999 == 0:
                 callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
                     f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

-        if question: res.append(beAdoc(deepcopy(doc), question, answer, eng))
+        if question:
+            res.append(beAdoc(deepcopy(doc), question, answer, eng))

         callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

@ -367,19 +372,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         callback(0.1, "Start to parse.")
         txt = get_text(filename, binary)
         lines = txt.split("\n")
-        last_question, last_answer = "", ""
+        _last_question, last_answer = "", ""
         question_stack, level_stack = [], []
         code_block = False
-        level_index = [-1] * 7
-        for index, l in enumerate(lines):
-            if l.strip().startswith('```'):
+        for index, line in enumerate(lines):
+            if line.strip().startswith('```'):
                 code_block = not code_block
             question_level, question = 0, ''
             if not code_block:
-                question_level, question = mdQuestionLevel(l)
+                question_level, question = mdQuestionLevel(line)

             if not question_level or question_level > 6:  # not a question
-                last_answer = f'{last_answer}\n{l}'
+                last_answer = f'{last_answer}\n{line}'
             else:  # is a question
                 if last_answer.strip():
                     sum_question = '\n'.join(question_stack)

@ -41,14 +41,16 @@ class Excel(ExcelParser):
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
             rows = list(ws.rows)
-            if not rows:continue
+            if not rows:
+                continue
             headers = [cell.value for cell in rows[0]]
             missed = set([i for i, h in enumerate(headers) if h is None])
             headers = [
                 cell.value for i,
                 cell in enumerate(
                     rows[0]) if i not in missed]
-            if not headers:continue
+            if not headers:
+                continue
             data = []
             for i, r in enumerate(rows[1:]):
                 rn += 1

@ -88,7 +90,6 @@ def trans_bool(s):

 def column_data_type(arr):
     arr = list(arr)
-    uni = len(set([a for a in arr if a is not None]))
     counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
     trans = {t: f for f, t in
              [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}

@ -157,7 +158,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
                 continue
             if i >= to_page:
                 break
-            row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
+            row = [field for field in line.split(kwargs.get("delimiter", "\t"))]
             if len(row) != len(headers):
                 fails.append(str(i))
                 continue
@ -13,12 +13,124 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from .embedding_model import *
-from .chat_model import *
-from .cv_model import *
-from .rerank_model import *
-from .sequence2txt_model import *
-from .tts_model import *
+from .embedding_model import (
+    OllamaEmbed,
+    LocalAIEmbed,
+    OpenAIEmbed,
+    AzureEmbed,
+    XinferenceEmbed,
+    QWenEmbed,
+    ZhipuEmbed,
+    FastEmbed,
+    YoudaoEmbed,
+    BaiChuanEmbed,
+    JinaEmbed,
+    DefaultEmbedding,
+    MistralEmbed,
+    BedrockEmbed,
+    GeminiEmbed,
+    NvidiaEmbed,
+    LmStudioEmbed,
+    OpenAI_APIEmbed,
+    CoHereEmbed,
+    TogetherAIEmbed,
+    PerfXCloudEmbed,
+    UpstageEmbed,
+    SILICONFLOWEmbed,
+    ReplicateEmbed,
+    BaiduYiyanEmbed,
+    VoyageEmbed,
+    HuggingFaceEmbed,
+    VolcEngineEmbed,
+)
+from .chat_model import (
+    GptTurbo,
+    AzureChat,
+    ZhipuChat,
+    QWenChat,
+    OllamaChat,
+    LocalAIChat,
+    XinferenceChat,
+    MoonshotChat,
+    DeepSeekChat,
+    VolcEngineChat,
+    BaiChuanChat,
+    MiniMaxChat,
+    MistralChat,
+    GeminiChat,
+    BedrockChat,
+    GroqChat,
+    OpenRouterChat,
+    StepFunChat,
+    NvidiaChat,
+    LmStudioChat,
+    OpenAI_APIChat,
+    CoHereChat,
+    LeptonAIChat,
+    TogetherAIChat,
+    PerfXCloudChat,
+    UpstageChat,
+    NovitaAIChat,
+    SILICONFLOWChat,
+    YiChat,
+    ReplicateChat,
+    HunyuanChat,
+    SparkChat,
+    BaiduYiyanChat,
+    AnthropicChat,
+    GoogleChat,
+    HuggingFaceChat,
+)
+
+from .cv_model import (
+    GptV4,
+    AzureGptV4,
+    OllamaCV,
+    XinferenceCV,
+    QWenCV,
+    Zhipu4V,
+    LocalCV,
+    GeminiCV,
+    OpenRouterCV,
+    LocalAICV,
+    NvidiaCV,
+    LmStudioCV,
+    StepFunCV,
+    OpenAI_APICV,
+    TogetherAICV,
+    YiCV,
+    HunyuanCV,
+)
+from .rerank_model import (
+    LocalAIRerank,
+    DefaultRerank,
+    JinaRerank,
+    YoudaoRerank,
+    XInferenceRerank,
+    NvidiaRerank,
+    LmStudioRerank,
+    OpenAI_APIRerank,
+    CoHereRerank,
+    TogetherAIRerank,
+    SILICONFLOWRerank,
+    BaiduYiyanRerank,
+    VoyageRerank,
+    QWenRerank,
+)
+from .sequence2txt_model import (
+    GPTSeq2txt,
+    QWenSeq2txt,
+    AzureSeq2txt,
+    XinferenceSeq2txt,
+    TencentCloudSeq2txt,
+)
+from .tts_model import (
+    FishAudioTTS,
+    QwenTTS,
+    OpenAITTS,
+    SparkTTS,
+    XinferenceTTS,
+)

 EmbeddingModel = {
     "Ollama": OllamaEmbed,
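The hunk above swaps "from .module import *" for explicit imports feeding the provider registries. A compact sketch of the same layout with placeholder classes (not the real module contents):

# providers.py (placeholder module)
class OpenAIEmbed:
    pass

class OllamaEmbed:
    pass

# __init__.py sketch: name each import instead of using a star import,
# then map provider labels to classes, as the EmbeddingModel mapping in this file does.
EmbeddingModel = {
    "OpenAI": OpenAIEmbed,
    "Ollama": OllamaEmbed,
}

print(EmbeddingModel["Ollama"].__name__)  # OllamaEmbed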
@ -48,7 +160,7 @@ EmbeddingModel = {
     "BaiduYiyan": BaiduYiyanEmbed,
     "Voyage AI": VoyageEmbed,
     "HuggingFace": HuggingFaceEmbed,
-    "VolcEngine":VolcEngineEmbed,
+    "VolcEngine": VolcEngineEmbed,
 }

 CvModel = {

@ -68,7 +180,7 @@ CvModel = {
     "OpenAI-API-Compatible": OpenAI_APICV,
     "TogetherAI": TogetherAICV,
     "01.AI": YiCV,
-    "Tencent Hunyuan": HunyuanCV
+    "Tencent Hunyuan": HunyuanCV,
 }

 ChatModel = {

@ -111,7 +223,7 @@ ChatModel = {
 }

 RerankModel = {
-    "LocalAI":LocalAIRerank,
+    "LocalAI": LocalAIRerank,
     "BAAI": DefaultRerank,
     "Jina": JinaRerank,
     "Youdao": YoudaoRerank,

@ -132,7 +244,7 @@ Seq2txtModel = {
     "Tongyi-Qianwen": QWenSeq2txt,
     "Azure-OpenAI": AzureSeq2txt,
     "Xinference": XinferenceSeq2txt,
-    "Tencent Cloud": TencentCloudSeq2txt
+    "Tencent Cloud": TencentCloudSeq2txt,
 }

 TTSModel = {
@ -69,7 +69,8 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices: continue
+                if not resp.choices:
+                    continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content

@ -81,7 +82,8 @@ class Base(ABC):
                     )
                 elif isinstance(resp.usage, dict):
                     total_tokens = resp.usage.get("total_tokens", total_tokens)
-                else: total_tokens = resp.usage.total_tokens
+                else:
+                    total_tokens = resp.usage.total_tokens

                 if resp.choices[0].finish_reason == "length":
                     if is_chinese(ans):

@ -98,13 +100,15 @@ class Base(ABC):

 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url = "https://api.openai.com/v1"
+        if not base_url:
+            base_url = "https://api.openai.com/v1"
         super().__init__(key, model_name, base_url)


 class MoonshotChat(Base):
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
-        if not base_url: base_url = "https://api.moonshot.cn/v1"
+        if not base_url:
+            base_url = "https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)


@ -128,7 +132,8 @@ class HuggingFaceChat(Base):

 class DeepSeekChat(Base):
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
-        if not base_url: base_url = "https://api.deepseek.com/v1"
+        if not base_url:
+            base_url = "https://api.deepseek.com/v1"
         super().__init__(key, model_name, base_url)


@ -202,7 +207,8 @@ class BaiChuanChat(Base):
                 stream=True,
                 **self._format_params(gen_conf))
             for resp in response:
-                if not resp.choices: continue
+                if not resp.choices:
+                    continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content

@ -313,8 +319,10 @@ class ZhipuChat(Base):
         if system:
             history.insert(0, {"role": "system", "content": system})
         try:
-            if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
-            if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+            if "presence_penalty" in gen_conf:
+                del gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf:
+                del gen_conf["frequency_penalty"]
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,

@ -333,8 +341,10 @@ class ZhipuChat(Base):
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
-        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
         ans = ""
         tk_count = 0
         try:

@ -345,7 +355,8 @@ class ZhipuChat(Base):
                 **gen_conf
             )
             for resp in response:
-                if not resp.choices[0].delta.content: continue
+                if not resp.choices[0].delta.content:
+                    continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":

@ -354,7 +365,8 @@ class ZhipuChat(Base):
                     else:
                         ans += LENGTH_NOTIFICATION_EN
                     tk_count = resp.usage.total_tokens
-                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "stop":
+                    tk_count = resp.usage.total_tokens
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

@ -372,11 +384,16 @@ class OllamaChat(Base):
             history.insert(0, {"role": "system", "content": system})
         try:
             options = {}
-            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
-            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
-            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+            if "temperature" in gen_conf:
+                options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf:
+                options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf:
+                options["top_p"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf:
+                options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf:
+                options["frequency_penalty"] = gen_conf["frequency_penalty"]
             response = self.client.chat(
                 model=self.model_name,
                 messages=history,
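The OllamaChat hunks above expand one-line conditionals that translate the incoming generation config into Ollama's options dict. A standalone sketch of that mapping (key names taken from the hunk, the helper name is invented):

def to_ollama_options(gen_conf: dict) -> dict:
    # One statement per line, matching the refactor above (Ruff E701).
    options = {}
    if "temperature" in gen_conf:
        options["temperature"] = gen_conf["temperature"]
    if "max_tokens" in gen_conf:
        options["num_predict"] = gen_conf["max_tokens"]
    if "top_p" in gen_conf:
        options["top_p"] = gen_conf["top_p"]
    return options

print(to_ollama_options({"temperature": 0.3, "max_tokens": 512}))
# {'temperature': 0.3, 'num_predict': 512}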
@ -392,11 +409,16 @@ class OllamaChat(Base):
         if system:
             history.insert(0, {"role": "system", "content": system})
         options = {}
-        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
-        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
-        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        if "temperature" in gen_conf:
+            options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf:
+            options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf:
+            options["top_p"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf:
+            options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            options["frequency_penalty"] = gen_conf["frequency_penalty"]
         ans = ""
         try:
             response = self.client.chat(

@ -636,7 +658,8 @@ class MistralChat(Base):
                 messages=history,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content: continue
+                if not resp.choices or not resp.choices[0].delta.content:
+                    continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":

@ -1196,7 +1219,8 @@ class SparkChat(Base):
         assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
         if model_name in model2version:
             model_version = model2version[model_name]
-        else: model_version = model_name
+        else:
+            model_version = model_name
         super().__init__(key, model_version, base_url)


@ -1281,8 +1305,10 @@ class AnthropicChat(Base):
         self.system = system
         if "max_tokens" not in gen_conf:
             gen_conf["max_tokens"] = 4096
-        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]

         ans = ""
         try:

@ -1312,8 +1338,10 @@ class AnthropicChat(Base):
         self.system = system
         if "max_tokens" not in gen_conf:
             gen_conf["max_tokens"] = 4096
-        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]

         ans = ""
         total_tokens = 0
@ -25,6 +25,7 @@ import base64
 from io import BytesIO
 import json
 import requests
+from transformers import GenerationConfig

 from rag.nlp import is_english
 from api.utils import get_uuid

@ -77,14 +78,16 @@ class Base(ABC):
                 stream=True
             )
             for resp in response:
-                if not resp.choices[0].delta.content: continue
+                if not resp.choices[0].delta.content:
+                    continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
-                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "stop":
+                    tk_count = resp.usage.total_tokens
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

@ -99,7 +102,7 @@ class Base(ABC):
         buffered = BytesIO()
         try:
             image.save(buffered, format="JPEG")
-        except Exception as e:
+        except Exception:
             image.save(buffered, format="PNG")
         return base64.b64encode(buffered.getvalue()).decode("utf-8")

@ -139,7 +142,8 @@ class Base(ABC):

 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+        if not base_url:
+            base_url="https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang

@ -149,7 +153,8 @@ class GptV4(Base):
         prompt = self.prompt(b64)
         for i in range(len(prompt)):
             for c in prompt[i]["content"]:
-                if "text" in c: c["type"] = "text"
+                if "text" in c:
+                    c["type"] = "text"

         res = self.client.chat.completions.create(
             model=self.model_name,

@ -171,7 +176,8 @@ class AzureGptV4(Base):
         prompt = self.prompt(b64)
         for i in range(len(prompt)):
             for c in prompt[i]["content"]:
-                if "text" in c: c["type"] = "text"
+                if "text" in c:
+                    c["type"] = "text"

         res = self.client.chat.completions.create(
             model=self.model_name,

@ -344,14 +350,16 @@ class Zhipu4V(Base):
                 stream=True
             )
             for resp in response:
-                if not resp.choices[0].delta.content: continue
+                if not resp.choices[0].delta.content:
+                    continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
-                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "stop":
+                    tk_count = resp.usage.total_tokens
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

@ -389,11 +397,16 @@ class OllamaCV(Base):
             if his["role"] == "user":
                 his["images"] = [image]
             options = {}
-            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
-            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
-            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+            if "temperature" in gen_conf:
+                options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf:
+                options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf:
+                options["top_k"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf:
+                options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf:
+                options["frequency_penalty"] = gen_conf["frequency_penalty"]
             response = self.client.chat(
                 model=self.model_name,
                 messages=history,

@ -414,11 +427,16 @@ class OllamaCV(Base):
             if his["role"] == "user":
                 his["images"] = [image]
         options = {}
-        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
-        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
-        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        if "temperature" in gen_conf:
+            options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf:
+            options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf:
+            options["top_k"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf:
+            options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            options["frequency_penalty"] = gen_conf["frequency_penalty"]
         ans = ""
         try:
             response = self.client.chat(

@ -469,7 +487,7 @@ class XinferenceCV(Base):

 class GeminiCV(Base):
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
-        from google.generativeai import client, GenerativeModel, GenerationConfig
+        from google.generativeai import client, GenerativeModel
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name

@ -503,7 +521,7 @@ class GeminiCV(Base):
             if his["role"] == "user":
                 his["parts"] = [his["content"]]
                 his.pop("content")
-        history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
+        history[-1]["parts"].append("data:image/jpeg;base64," + image)
|
||||||
|
|
||||||
response = self.model.generate_content(history, generation_config=GenerationConfig(
|
response = self.model.generate_content(history, generation_config=GenerationConfig(
|
||||||
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
|
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
|
||||||
@ -519,7 +537,6 @@ class GeminiCV(Base):
|
|||||||
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
ans = ""
|
ans = ""
|
||||||
tk_count = 0
|
|
||||||
try:
|
try:
|
||||||
for his in history:
|
for his in history:
|
||||||
if his["role"] == "assistant":
|
if his["role"] == "assistant":
|
||||||
@ -529,14 +546,15 @@ class GeminiCV(Base):
|
|||||||
if his["role"] == "user":
|
if his["role"] == "user":
|
||||||
his["parts"] = [his["content"]]
|
his["parts"] = [his["content"]]
|
||||||
his.pop("content")
|
his.pop("content")
|
||||||
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
|
history[-1]["parts"].append("data:image/jpeg;base64," + image)
|
||||||
|
|
||||||
response = self.model.generate_content(history, generation_config=GenerationConfig(
|
response = self.model.generate_content(history, generation_config=GenerationConfig(
|
||||||
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
|
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
|
||||||
top_p=gen_conf.get("top_p", 0.7)), stream=True)
|
top_p=gen_conf.get("top_p", 0.7)), stream=True)
|
||||||
|
|
||||||
for resp in response:
|
for resp in response:
|
||||||
if not resp.text: continue
|
if not resp.text:
|
||||||
|
continue
|
||||||
ans += resp.text
|
ans += resp.text
|
||||||
yield ans
|
yield ans
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -632,7 +650,8 @@ class NvidiaCV(Base):
|
|||||||
|
|
||||||
class StepFunCV(GptV4):
|
class StepFunCV(GptV4):
|
||||||
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
|
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
|
||||||
if not base_url: base_url="https://api.stepfun.com/v1"
|
if not base_url:
|
||||||
|
base_url="https://api.stepfun.com/v1"
|
||||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
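Most of the hunks above make the same mechanical change: a statement that shared a line with its if-clause is moved into an indented block, the pattern style checkers report as pycodestyle E701 ("multiple statements on one line"). A minimal sketch of the two forms, reusing the base_url default from the hunks above:

    if not base_url: base_url = "https://api.openai.com/v1"   # one-line suite, flagged by the linter

    if not base_url:
        base_url = "https://api.openai.com/v1"                # block form used throughout this commit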
@@ -15,12 +15,9 @@
 #
 import requests
 from openai.lib.azure import AzureOpenAI
-from zhipuai import ZhipuAI
 import io
 from abc import ABC
-from ollama import Client
 from openai import OpenAI
-import os
 import json
 from rag.utils import num_tokens_from_string
 import base64
@@ -49,7 +46,8 @@ class Base(ABC):
 
 class GPTSeq2txt(Base):
     def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url = "https://api.openai.com/v1"
+        if not base_url:
+            base_url = "https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
@@ -16,7 +16,6 @@
 
 import _thread as thread
 import base64
-import datetime
 import hashlib
 import hmac
 import json
@@ -175,7 +174,8 @@ class QwenTTS(Base):
 
 class OpenAITTS(Base):
     def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url = "https://api.openai.com/v1"
+        if not base_url:
+            base_url = "https://api.openai.com/v1"
         self.api_key = key
         self.model_name = model_name
         self.base_url = base_url
@@ -222,7 +222,8 @@ def bullets_category(sections):
 
 def is_english(texts):
     eng = 0
-    if not texts: return False
+    if not texts:
+        return False
     for t in texts:
         if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
             eng += 1
@@ -250,7 +251,8 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
     res = []
     # wrap up as es documents
     for ck in chunks:
-        if len(ck.strip()) == 0:continue
+        if len(ck.strip()) == 0:
+            continue
         logging.debug("-- {}".format(ck))
         d = copy.deepcopy(doc)
         if pdf_parser:
@@ -269,7 +271,8 @@ def tokenize_chunks_docx(chunks, doc, eng, images):
     res = []
     # wrap up as es documents
     for ck, image in zip(chunks, images):
-        if len(ck.strip()) == 0:continue
+        if len(ck.strip()) == 0:
+            continue
         logging.debug("-- {}".format(ck))
         d = copy.deepcopy(doc)
         d["image"] = image
@@ -288,8 +291,10 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
             d = copy.deepcopy(doc)
             tokenize(d, rows, eng)
             d["content_with_weight"] = rows
-            if img: d["image"] = img
-            if poss: add_positions(d, poss)
+            if img:
+                d["image"] = img
+            if poss:
+                add_positions(d, poss)
             res.append(d)
             continue
         de = "; " if eng else "；"
@@ -387,9 +392,9 @@ def title_frequency(bull, sections):
         if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
             levels[i] = bullets_size
     most_level = bullets_size+1
-    for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
-        if l <= bullets_size:
-            most_level = l
+    for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
+        if level <= bullets_size:
+            most_level = level
             break
     return most_level, levels
 
@@ -504,7 +509,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
     def add_chunk(t, pos):
         nonlocal cks, tk_nums, delimiter
         tnum = num_tokens_from_string(t)
-        if not pos: pos = ""
+        if not pos:
+            pos = ""
         if tnum < 8:
             pos = ""
         # Ensure that the length of the merged chunk does not exceed chunk_token_num
@ -121,7 +121,8 @@ class FulltextQueryer:
|
|||||||
keywords.append(tt)
|
keywords.append(tt)
|
||||||
twts = self.tw.weights([tt])
|
twts = self.tw.weights([tt])
|
||||||
syns = self.syn.lookup(tt)
|
syns = self.syn.lookup(tt)
|
||||||
if syns and len(keywords) < 32: keywords.extend(syns)
|
if syns and len(keywords) < 32:
|
||||||
|
keywords.extend(syns)
|
||||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||||
tms = []
|
tms = []
|
||||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||||
@ -147,7 +148,8 @@ class FulltextQueryer:
|
|||||||
|
|
||||||
tk_syns = self.syn.lookup(tk)
|
tk_syns = self.syn.lookup(tk)
|
||||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||||
if len(keywords) < 32: keywords.extend([s for s in tk_syns if s])
|
if len(keywords) < 32:
|
||||||
|
keywords.extend([s for s in tk_syns if s])
|
||||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||||
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]
|
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]
|
||||||
|
|
||||||
|
@ -104,7 +104,6 @@ class RagTokenizer:
|
|||||||
return HanziConv.toSimplified(line)
|
return HanziConv.toSimplified(line)
|
||||||
|
|
||||||
def dfs_(self, chars, s, preTks, tkslist):
|
def dfs_(self, chars, s, preTks, tkslist):
|
||||||
MAX_L = 10
|
|
||||||
res = s
|
res = s
|
||||||
# if s > MAX_L or s>= len(chars):
|
# if s > MAX_L or s>= len(chars):
|
||||||
if s >= len(chars):
|
if s >= len(chars):
|
||||||
@ -184,12 +183,6 @@ class RagTokenizer:
|
|||||||
return sorted(res, key=lambda x: x[1], reverse=True)
|
return sorted(res, key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
def merge_(self, tks):
|
def merge_(self, tks):
|
||||||
patts = [
|
|
||||||
(r"[ ]+", " "),
|
|
||||||
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
|
|
||||||
]
|
|
||||||
# for p,s in patts: tks = re.sub(p, s, tks)
|
|
||||||
|
|
||||||
# if split chars is part of token
|
# if split chars is part of token
|
||||||
res = []
|
res = []
|
||||||
tks = re.sub(r"[ ]+", " ", tks).split()
|
tks = re.sub(r"[ ]+", " ", tks).split()
|
||||||
@ -284,7 +277,8 @@ class RagTokenizer:
|
|||||||
same = 0
|
same = 0
|
||||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||||
same += 1
|
same += 1
|
||||||
if same > 0: res.append(" ".join(tks[j: j + same]))
|
if same > 0:
|
||||||
|
res.append(" ".join(tks[j: j + same]))
|
||||||
_i = i + same
|
_i = i + same
|
||||||
_j = j + same
|
_j = j + same
|
||||||
j = _j + 1
|
j = _j + 1
|
||||||
|
@@ -62,10 +62,10 @@ class Dealer:
             res = {}
             f = open(fnm, "r")
             while True:
-                l = f.readline()
-                if not l:
+                line = f.readline()
+                if not line:
                     break
-                arr = l.replace("\n", "").split("\t")
+                arr = line.replace("\n", "").split("\t")
                 if len(arr) < 2:
                     res[arr[0]] = 0
                 else:
@@ -47,7 +48,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     def __call__(self, chunks, random_state, callback=None):
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)
-        if len(chunks) <= 1: return
+        if len(chunks) <= 1:
+            return
         chunks = [(s, a) for s, a in chunks if len(a) > 0]
 
         def summarize(ck_idx, lock):
@@ -66,7 +67,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 logging.debug(f"SUM: {cnt}")
                 embds, _ = self._embd_model.encode([cnt])
                 with lock:
-                    if not len(embds[0]): return
+                    if not len(embds[0]):
+                        return
                     chunks.append((cnt, embds[0]))
             except Exception as e:
                 logging.exception("summarize got exception")
@@ -33,14 +33,16 @@ def collect():
 
 def main():
     locations = collect()
-    if not locations:return
+    if not locations:
+        return
     logging.info(f"TASKS: {len(locations)}")
     for kb_id, loc in locations:
        try:
            if REDIS_CONN.is_alive():
                try:
                    key = "{}/{}".format(kb_id, loc)
-                    if REDIS_CONN.exist(key):continue
+                    if REDIS_CONN.exist(key):
+                        continue
                    file_bin = STORAGE_IMPL.get(kb_id, loc)
                    REDIS_CONN.transaction(key, file_bin, 12 * 60)
                    logging.info("CACHE: {}".format(loc))
@@ -23,18 +23,12 @@ import os
 
 from api.utils.log_utils import initRootLogger
 
-CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
-CONSUMER_NAME = "task_executor_" + CONSUMER_NO
-LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
-initRootLogger(CONSUMER_NAME, LOG_LEVELS)
 
 from datetime import datetime
 import json
-import os
 import hashlib
 import copy
 import re
-import sys
 import time
 import threading
 from functools import partial
@@ -63,6 +57,11 @@ from rag.utils import rmSpace, num_tokens_from_string
 from rag.utils.redis_conn import REDIS_CONN, Payload
 from rag.utils.storage_factory import STORAGE_IMPL
 
+CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
+CONSUMER_NAME = "task_executor_" + CONSUMER_NO
+LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
+initRootLogger(CONSUMER_NAME, LOG_LEVELS)
+
 BATCH_SIZE = 64
 
 FACTORY = {
@@ -201,7 +200,8 @@ def build_chunks(task, progress_callback):
         "doc_id": task["doc_id"],
         "kb_id": str(task["kb_id"])
     }
-    if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
+    if task["pagerank"]:
+        doc["pagerank_fea"] = int(task["pagerank"])
     el = 0
     for ck in cks:
         d = copy.deepcopy(doc)
@@ -342,7 +342,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
         "docnm_kwd": row["name"],
         "title_tks": rag_tokenizer.tokenize(row["name"])
     }
-    if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
+    if row["pagerank"]:
+        doc["pagerank_fea"] = int(row["pagerank"])
     res = []
     tk_count = 0
     for content, vctr in chunks[original_length:]:
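Earlier in this file's hunks, the CONSUMER_NO / initRootLogger block moves from the middle of the import section to just after it, presumably to satisfy the "module level import not at top of file" check (pycodestyle E402), which flags any import that follows ordinary module-level statements. A minimal sketch of what that rule reports:

    import sys
    CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]  # ordinary module-level statement
    import os                                                # E402: import after module-level code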
@@ -41,15 +41,15 @@ def findMaxDt(fnm):
     try:
         with open(fnm, "r") as f:
             while True:
-                l = f.readline()
-                if not l:
+                line = f.readline()
+                if not line:
                     break
-                l = l.strip("\n")
-                if l == 'nan':
+                line = line.strip("\n")
+                if line == 'nan':
                     continue
-                if l > m:
-                    m = l
-    except Exception as e:
+                if line > m:
+                    m = line
+    except Exception:
         pass
     return m
 
@@ -59,15 +59,15 @@ def findMaxTm(fnm):
     try:
         with open(fnm, "r") as f:
             while True:
-                l = f.readline()
-                if not l:
+                line = f.readline()
+                if not line:
                     break
-                l = l.strip("\n")
-                if l == 'nan':
+                line = line.strip("\n")
+                if line == 'nan':
                     continue
-                if int(l) > m:
-                    m = int(l)
-    except Exception as e:
+                if int(line) > m:
+                    m = int(line)
+    except Exception:
         pass
     return m
 
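The findMaxDt and findMaxTm hunks rename the single-letter variable l to line and drop the unused exception binding. Single-character names such as l, O, and I are reported as ambiguous (pycodestyle E741) because they are easy to confuse with the digits 1 and 0; a minimal sketch:

    l = f.readline()      # E741: ambiguous variable name "l"
    line = f.readline()   # the same statement with an unambiguous name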
@@ -32,7 +32,7 @@ class RAGFlowAzureSasBlob(object):
         self.conn = None
 
     def health(self):
-        bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
+        _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
         return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
 
     def put(self, bucket, fnm, binary):
@@ -36,7 +36,7 @@ class RAGFlowAzureSpnBlob(object):
         self.conn = None
 
     def health(self):
-        bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
+        _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
         f = self.conn.create_file(fnm)
         f.append_data(binary, offset=0, length=len(binary))
         return f.flush_data(len(binary))
@@ -132,7 +132,8 @@ class ESConnection(DocStoreConnection):
                     bqry.filter.append(
                         Q("bool", must_not=Q("range", available_int={"lt": 1})))
                 continue
-            if not v: continue
+            if not v:
+                continue
            if isinstance(v, list):
                bqry.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
@@ -1,14 +1,21 @@
-from beartype.claw import beartype_this_package
-beartype_this_package() # <-- raise exceptions in your code
-
 import importlib.metadata
 
-__version__ = importlib.metadata.version("ragflow_sdk")
-
 from .ragflow import RAGFlow
 from .modules.dataset import DataSet
 from .modules.chat import Chat
 from .modules.session import Session
 from .modules.document import Document
 from .modules.chunk import Chunk
 from .modules.agent import Agent
+
+__version__ = importlib.metadata.version("ragflow_sdk")
+
+__all__ = [
+    "RAGFlow",
+    "DataSet",
+    "Chat",
+    "Session",
+    "Document",
+    "Chunk",
+    "Agent"
+]
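The rewritten SDK __init__ above gains an explicit __all__ list: names listed there are treated as intentional re-exports (which is how the F401 unused-import warning is normally silenced for package-level imports), and they also define what a wildcard import exposes. A minimal usage sketch, assuming a running RAGFlow server at its default address and a placeholder API key:

    from ragflow_sdk import RAGFlow

    rag = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")
    print(rag.list_datasets())  # the client object is the package's main entry point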
@@ -29,7 +29,7 @@ class Session(Base):
                 raise Exception(json_data["message"])
             if line.startswith("data:"):
                 json_data = json.loads(line[5:])
-                if json_data["data"] != True:
+                if not json_data["data"]:
                     answer = json_data["data"]["answer"]
                     reference = json_data["data"]["reference"]
                     temp_dict = {
@@ -1,5 +1,3 @@
-import string
-import random
 import os
 import pytest
 import requests
@@ -39,7 +39,6 @@ def update_dataset(auth, json_req):
 def upload_file(auth, dataset_id, path):
     authorization = {"Authorization": auth}
     url = f"{HOST_ADDRESS}/v1/document/upload"
-    base_name = os.path.basename(path)
     json_req = {
         "kb_id": dataset_id,
     }
@@ -1,3 +1,3 @@
 def test_get_email(get_email):
-    print(f"\nEmail account:",flush=True)
+    print("\nEmail account:",flush=True)
     print(f"{get_email}\n",flush=True)
@@ -13,14 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, upload_file, DATASET_NAME_LIMIT
+from common import create_dataset, list_dataset, rm_dataset, upload_file
 from common import list_document, get_docs_info, parse_docs
 from time import sleep
 from timeit import default_timer as timer
-import re
-import pytest
-import random
-import string
 
 
 def test_parse_txt_document(get_auth):
@@ -1,6 +1,5 @@
-from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
+from common import create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
 import re
-import pytest
 import random
 import string
 
@@ -33,8 +32,6 @@ def test_dataset(get_auth):
 
 def test_dataset_1k_dataset(get_auth):
     # create dataset
-    authorization = {"Authorization": get_auth}
-    url = f"{HOST_ADDRESS}/v1/kb/create"
     for i in range(1000):
         res = create_dataset(get_auth, f"test_create_dataset_{i}")
         assert res.get("code") == 0, f"{res.get('message')}"
@@ -76,7 +73,7 @@ def test_duplicated_name_dataset(get_auth):
         dataset_id = item.get("id")
         dataset_list.append(dataset_id)
         match = re.match(pattern, dataset_name)
-        assert match != None
+        assert match is not None
 
     for dataset_id in dataset_list:
         res = rm_dataset(get_auth, dataset_id)
@@ -1,3 +1,3 @@
 def test_get_email(get_email):
-    print(f"\nEmail account:",flush=True)
+    print("\nEmail account:",flush=True)
     print(f"{get_email}\n",flush=True)
@@ -1,4 +1,4 @@
-from ragflow_sdk import RAGFlow,Agent
+from ragflow_sdk import RAGFlow
 from common import HOST_ADDRESS
 import pytest
 