ragflow/rag/nlp/__init__.py
KevinHuSh c6b6c748ae
fix file encoding detection bug (#653)
### What problem does this PR solve?

#651 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2024-05-07 10:01:24 +08:00

392 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
from collections import Counter
from rag.utils import num_tokens_from_string
from . import rag_tokenizer
import re
import copy
# Candidate encodings tried in order by find_codec().  'utf-8' leads because it
# is both the most common and self-validating; legacy CJK and single-byte
# codepages follow.  Order matters: the first codec that decodes cleanly wins.
all_codecs = [
    'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
    'cp037', 'cp273', 'cp424', 'cp437',
    'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
    'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
    'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
    'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
    'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
    'gb2312', 'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
    'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
    'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
    'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
    'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
    'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
    'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
    # BUG FIX: a missing comma previously fused 'utf_32_le' and 'utf_16_be'
    # into the invalid codec name 'utf_32_leutf_16_be', so neither was tried.
    'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7'
]


def find_codec(blob):
    """Guess the text encoding of *blob* (bytes).

    Each codec in ``all_codecs`` is tried first on a 1 KiB prefix (fast path)
    and then, if that fails, on the whole blob — the prefix may end in the
    middle of a multi-byte character and raise a spurious decode error.
    Returns the name of the first codec that decodes cleanly, or "utf-8" as a
    last resort.
    """
    for codec in all_codecs:
        try:
            blob[:1024].decode(codec)
            return codec
        except Exception:
            pass
        try:
            blob.decode(codec)
            return codec
        except Exception:
            pass
    return "utf-8"
# Heading/bullet pattern groups, one sub-list per document style:
#   0: Chinese statute style (编/部分, 章, 节, 条, parenthesized Chinese numerals)
#   1: numeric style (第N章/第N节, "1.", "1.2", "1.2.3", "1.2.3.4")
#   2: Chinese book style (章/节, Chinese numerals, "(一)", "(1)")
#   3: English style (PART ONE, Chapter I, Section 1, Article 1)
# bullets_category() picks the group that best matches a document; patterns
# within a group are ordered from shallowest to deepest heading level.
BULLET_PATTERN = [[
    r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"第[零一二三四五六七八九十百0-9]+条",
    r"[\(][零一二三四五六七八九十百]+[\)]",
], [
    r"第[0-9]+章",
    r"第[0-9]+节",
    r"[0-9]{,2}[\. 、]",
    r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"[零一二三四五六七八九十百]+[ 、]",
    r"[\(][零一二三四五六七八九十百]+[\)]",
    r"[\(][0-9]{,2}[\)]",
], [
    r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"Chapter (I+V?|VI*|XI|IX|X)",
    r"Section [0-9]+",
    r"Article [0-9]+"
]
]
def random_choices(arr, k):
    """Sample *k* items from *arr* with replacement, capping *k* at len(arr)."""
    return random.choices(arr, k=min(k, len(arr)))
def not_bullet(line):
    """Return True when *line* is a false bullet: a bare zero, a number
    followed by a count/range (e.g. "12 3", "3 个"), or a dotted leader
    like "1....." (typical of table-of-contents rows)."""
    non_bullet_res = (r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}")
    return any(re.match(p, line) for p in non_bullet_res)
def bullets_category(sections):
    """Pick the BULLET_PATTERN group that best matches *sections*.

    For every pattern group, count how many section strings start with one of
    its bullet patterns (each section counted at most once per group, and
    lines rejected by not_bullet() are skipped).  Returns the index of the
    group with the most hits, or -1 when no group matches anything.
    """
    scores = []
    for group in BULLET_PATTERN:
        hit = 0
        for sec in sections:
            if not_bullet(sec):
                continue
            if any(re.match(p, sec) for p in group):
                hit += 1
        scores.append(hit)
    # First group with the strictly highest score wins; all-zero means "none".
    best, best_score = -1, 0
    for idx, score in enumerate(scores):
        if score > best_score:
            best, best_score = idx, score
    return best
def is_english(texts):
    """Heuristic language check: True when more than 80% of the strings in
    *texts* start with at least two ASCII letters; False for empty input."""
    if not texts:
        return False
    letterish = sum(1 for t in texts if re.match(r"[a-zA-Z]{2,}", t.strip()))
    return letterish / len(texts) > 0.8
def tokenize(d, t, eng):
    """Fill chunk dict *d* with the raw text plus coarse and fine token fields.

    Table markup tags are stripped before tokenizing so they do not pollute
    the token stream.  *eng* is accepted for interface compatibility but is
    not used here.
    """
    d["content_with_weight"] = t
    plain = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
    d["content_ltks"] = rag_tokenizer.tokenize(plain)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def tokenize_chunks(chunks, doc, eng, pdf_parser):
    """Wrap text *chunks* into ES documents cloned from *doc*.

    Blank chunks are skipped.  When *pdf_parser* is given, each chunk is
    cropped to an image with positions and stripped of its position tags
    before tokenizing; parsers that do not implement crop() are tolerated.
    """
    docs = []
    for chunk in chunks:
        if not chunk.strip():
            continue
        print("--", chunk)  # debug trace kept from the original
        d = copy.deepcopy(doc)
        if pdf_parser:
            try:
                d["image"], positions = pdf_parser.crop(chunk, need_position=True)
                add_positions(d, positions)
                chunk = pdf_parser.remove_tag(chunk)
            except NotImplementedError:
                pass
        tokenize(d, chunk, eng)
        docs.append(d)
    return docs
def tokenize_table(tbls, doc, eng, batch_size=10):
    """Convert extracted tables into ES documents cloned from *doc*.

    Each entry of *tbls* is ((image, rows), positions).  A string *rows* is
    stored as a single document; a list of row strings is joined in batches
    of *batch_size* rows ("; "-separated for English, space-separated
    otherwise), one document per batch.
    """
    docs = []
    for (img, rows), poss in tbls:
        if not rows:
            continue
        if isinstance(rows, str):
            d = copy.deepcopy(doc)
            tokenize(d, rows, eng)
            d["content_with_weight"] = rows
            if img:
                d["image"] = img
            if poss:
                add_positions(d, poss)
            docs.append(d)
            continue
        sep = "; " if eng else " "
        for start in range(0, len(rows), batch_size):
            d = copy.deepcopy(doc)
            batch_text = sep.join(rows[start:start + batch_size])
            tokenize(d, batch_text, eng)
            d["image"] = img
            add_positions(d, poss)
            docs.append(d)
    return docs
def add_positions(d, poss):
    """Attach page/position metadata lists to document dict *d*.

    *poss* is a sequence of (page, left, right, top, bottom) tuples; page
    numbers are converted from 0-based to 1-based.  When *poss* is falsy the
    dict is left untouched.
    """
    if not poss:
        return
    d["page_num_int"] = []
    d["position_int"] = []
    d["top_int"] = []
    for page, left, right, top, bottom in poss:
        page_one_based = int(page + 1)
        d["page_num_int"].append(page_one_based)
        d["top_int"].append(int(top))
        d["position_int"].append(
            (page_one_based, int(left), int(right), int(top), int(bottom)))
def remove_contents_table(sections, eng=False):
    """Strip a table-of-contents (or acknowledgements) run from *sections* in place.

    *sections* items are plain strings or (text, layout) tuples.  When a
    heading such as "Contents"/"目录" is found, that heading and the TOC
    entries after it are removed, up to (but not including) the first body
    section that repeats the first TOC entry's prefix — the first 3 chars for
    CJK text, the first two words when *eng* is True.
    """
    i = 0
    while i < len(sections):
        def get(i):
            nonlocal sections
            return (sections[i] if isinstance(sections[i],
                                              type("")) else sections[i][0]).strip()
        # BUG FIX: re.IGNORECASE used to be passed as re.sub()'s positional
        # *count* argument (silently capping replacements at 2 and applying
        # no flag at all); it belongs on the heading match so that
        # "Contents"/"CONTENTS" are recognized too.
        if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
                        re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0]),
                        re.IGNORECASE):
            i += 1
            continue
        sections.pop(i)
        if i >= len(sections):
            break
        prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2])
        # Skip empty lines right after the TOC heading until a usable prefix
        # is found.
        while not prefix:
            sections.pop(i)
            if i >= len(sections):
                break
            prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2])
        sections.pop(i)
        if i >= len(sections) or not prefix:
            break
        # Drop everything up to the first section (within a 128-line window)
        # whose text re-starts with the remembered prefix.
        # NOTE(review): *prefix* is used as a regex pattern here, so regex
        # metacharacters in headings are interpreted — confirm intended.
        for j in range(i, min(i + 128, len(sections))):
            if not re.match(prefix, get(j)):
                continue
            for _ in range(i, j):
                sections.pop(i)
            break
def make_colon_as_title(sections):
    """For (text, layout) sections ending in a colon, insert the clause
    preceding the colon as a separate ("...", "title") section.

    A list of plain strings is returned unchanged.  Otherwise *sections* is
    mutated in place and the function implicitly returns None.
    """
    if not sections:
        return []
    if isinstance(sections[0], type("")):
        return sections
    i = 0
    while i < len(sections):
        txt, layout = sections[i]
        i += 1
        # Text before any "@" position marker.
        txt = txt.split("@")[0].strip()
        if not txt:
            continue
        # Only lines ending with a half- or full-width colon are candidates.
        if txt[-1] not in ":":
            continue
        # Reverse the text so splitting finds the last sentence break before
        # the trailing colon; arr[0] is then the (reversed) title candidate.
        txt = txt[::-1]
        arr = re.split(r"([。?!!?;]| .)", txt)
        # NOTE(review): arr[1] is the captured delimiter (1-2 chars), so
        # len(arr[1]) < 32 is always true whenever a delimiter exists and the
        # insert below is effectively unreachable; len(arr[0]) (or the
        # remainder) was probably intended — confirm before changing.
        if len(arr) < 2 or len(arr[1]) < 32:
            continue
        sections.insert(i - 1, (arr[0][::-1], "title"))
        i += 1
def title_frequency(bull, sections):
    """Determine the dominant heading level among *sections*.

    *bull* is the BULLET_PATTERN group index (from bullets_category).
    Returns (most_level, levels): levels[i] is the pattern index matched by
    section i, bullets_size for layout-detected titles, and bullets_size + 1
    for plain body text; most_level is the most frequent real heading level,
    or bullets_size + 1 when there is none (or bull < 0 / no sections).
    """
    bullets_size = len(BULLET_PATTERN[bull])
    levels = [bullets_size + 1] * len(sections)
    if not sections or bull < 0:
        return bullets_size + 1, levels
    for idx, (txt, layout) in enumerate(sections):
        matched = False
        for level, patt in enumerate(BULLET_PATTERN[bull]):
            if re.match(patt, txt.strip()) and not not_bullet(txt):
                levels[idx] = level
                matched = True
                break
        if not matched and re.search(r"(title|head)", layout) \
                and not not_title(txt.split("@")[0]):
            levels[idx] = bullets_size
    # Most frequent level that is a real bullet level (<= bullets_size).
    most_level = bullets_size + 1
    for level, _count in Counter(levels).most_common():
        if level <= bullets_size:
            most_level = level
            break
    return most_level, levels
def not_title(txt):
    """Return a truthy value when *txt* does NOT look like a heading.

    Chinese statute articles ("第N条") always count as titles (False).
    Long lines — more than 12 space-separated words, or >= 32 chars with no
    space at all — count as body text (True).  Otherwise the result is the
    re.search match for sentence punctuation (a Match object or None).
    """
    if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
        return False
    if len(txt.split(" ")) > 12:
        return True
    if txt.find(" ") < 0 and len(txt) >= 32:
        return True
    return re.search(r"[,;,。;!!]", txt)
def hierarchical_merge(bull, sections, depth):
    """Group *sections* into chunks that follow the document's heading hierarchy.

    *bull* is the BULLET_PATTERN group index (from bullets_category); *depth*
    is how many of the deepest levels may seed a chunk.  Each seed section is
    paired with its nearest preceding heading at every shallower level, and
    the resulting chunks are then packed into lists of roughly <= 218 tokens.
    Returns a list of lists of section strings (the first list may be empty).
    """
    if not sections or bull < 0:
        return []
    if isinstance(sections[0], type("")):
        sections = [(s, "") for s in sections]
    # Drop near-empty sections and bare page numbers like "12".
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
    bullets_size = len(BULLET_PATTERN[bull])
    # levels[j] holds indices of sections at heading level j; slot
    # bullets_size collects layout-detected titles, bullets_size + 1 the
    # plain body text.
    levels = [[] for _ in range(bullets_size + 2)]
    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()):
                levels[j].append(i)
                break
        else:
            if re.search(r"(title|head)", layout) and not not_title(txt):
                levels[bullets_size].append(i)
            else:
                levels[bullets_size + 1].append(i)
    sections = [t for t, _ in sections]
    # for s in sections: print("--", s)

    def binary_search(arr, target):
        # Index of the last element of sorted *arr* strictly below *target*,
        # or -1 when every element is >= target.  *target* itself must not be
        # in *arr* (the assert): callers always search a different level.
        if not arr:
            return -1
        if target > arr[-1]:
            return len(arr) - 1
        if target < arr[0]:
            return -1
        s, e = 0, len(arr)
        while e - s > 1:
            i = (e + s) // 2
            if target > arr[i]:
                s = i
                continue
            elif target < arr[i]:
                e = i
                continue
            else:
                assert False
        return s

    cks = []
    readed = [False] * len(sections)
    # Reverse so index 0 is body text and increasing index means shallower
    # heading levels; the first *depth* (deepest) levels seed chunks.
    levels = levels[::-1]
    for i, arr in enumerate(levels[:depth]):
        for j in arr:
            if readed[j]:
                continue
            readed[j] = True
            cks.append([j])
            if i + 1 == len(levels) - 1:
                continue
            # Walk the shallower levels, attaching the nearest preceding
            # heading at each one.
            for ii in range(i + 1, len(levels)):
                jj = binary_search(levels[ii], j)
                if jj < 0:
                    continue
                # NOTE(review): jj is a position within levels[ii] but is
                # compared against a section index — looks questionable;
                # confirm intent before touching.
                if jj > cks[-1][-1]:
                    cks[-1].pop(-1)
                cks[-1].append(levels[ii][jj])
                # NOTE(review): this loop reuses the name ii, clobbering the
                # outer level index; harmless in Python's for semantics but
                # easy to misread.
                for ii in cks[-1]:
                    readed[ii] = True
    if not cks:
        return cks

    # Convert index chunks to text (deepest-first was built, so reverse to
    # heading-first order).
    for i in range(len(cks)):
        cks[i] = [sections[j] for j in cks[i][::-1]]
        print("--------------\n", "\n* ".join(cks[i]))

    # Pack single-section chunks together while they fit under ~218 tokens
    # (position tags "@@..." excluded from the count); multi-section chunks
    # always start a fresh result list.
    res = [[]]
    num = [0]
    for ck in cks:
        if len(ck) == 1:
            n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
            if n + num[-1] < 218:
                res[-1].append(ck[0])
                num[-1] += n
                continue
            res.append(ck)
            num.append(n)
            continue
        res.append(ck)
        num.append(218)
    return res
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
    """Greedily merge *sections* into chunks of roughly *chunk_token_num* tokens.

    *sections* items are strings or (text, position-tag) tuples.  Sections are
    appended to the current chunk until it exceeds the token budget, then a
    new chunk is started.  A section's position tag is added at most once per
    chunk and skipped entirely for sections shorter than 8 tokens.  Returns
    the list of chunk strings (first chunk may be empty for empty input runs).

    Note: *delimiter* is kept for interface compatibility but is currently
    unused — the original implementation had a sentence-splitting loop after
    an unconditional ``continue``, i.e. unreachable dead code; it has been
    removed here without any behavior change.
    """
    if not sections:
        return []
    if isinstance(sections[0], type("")):
        sections = [(s, "") for s in sections]
    cks = [""]
    tk_nums = [0]

    def add_chunk(t, pos):
        nonlocal cks, tk_nums
        tnum = num_tokens_from_string(t)
        if tnum < 8:
            pos = ""  # too short to be worth carrying a position tag
        if tk_nums[-1] > chunk_token_num:
            # Current chunk is over budget: start a new one.
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            tk_nums.append(tnum)
        else:
            # Append to the current chunk; add the tag only if missing.
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            tk_nums[-1] += tnum

    for sec, pos in sections:
        add_chunk(sec, pos)
    return cks