Fix infinite recursion in RagTokenizer when processing repetitive characters (#6109)
### What problem does this PR solve?

Fixes #6085. RagTokenizer's dfs_() function falls into infinite recursion when processing text with repetitive Chinese characters (e.g., "一一一一一十一十一十一..." or "一一一一一一十十十十十十十二十二十二..."), causing memory leaks.

### Type of change

Three optimizations were implemented in dfs_():

1. Added memoization with a _memo dictionary to cache computed results.
2. Added recursion depth limiting with a _depth parameter (max 10 levels).
3. Implemented special handling for repetitive character sequences.

A minimal reproduction sketch follows below.

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
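For reference, here is one way to exercise the problematic input and confirm the fix. This is a hedged sketch only: the `rag.nlp.rag_tokenizer` import path and the module-level `tokenize()` helper are assumptions about the surrounding RAGFlow code, not something introduced by this commit.

```python
# Hedged repro sketch -- the import path and the module-level tokenize() helper
# are assumptions about RAGFlow's layout, not part of this patch.
from rag.nlp import rag_tokenizer

# Repetitive inputs of this shape previously drove dfs_() into unbounded
# recursion and runaway memory use (#6085).
text = "一一一一一" + "十一" * 50

# With memoization, the depth cap, and the repetitive-run shortcut in place,
# tokenization returns promptly instead of recursing until the stack blows up.
print(rag_tokenizer.tokenize(text))
```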
This commit is contained in:
parent 58e6e7b668, commit ead5f7aba9
@@ -116,55 +116,86 @@ class RagTokenizer:
     def _tradi2simp(self, line):
         return HanziConv.toSimplified(line)

-    def dfs_(self, chars, s, preTks, tkslist):
+    def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
+        if _memo is None:
+            _memo = {}
+        MAX_DEPTH = 10
+        if _depth > MAX_DEPTH:
+            if s < len(chars):
+                copy_pretks = copy.deepcopy(preTks)
+                remaining = "".join(chars[s:])
+                copy_pretks.append((remaining, (-12, '')))
+                tkslist.append(copy_pretks)
+            return s
+
+        state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
+        if state_key in _memo:
+            return _memo[state_key]
+
         res = s
-        if len(tkslist) >= 2048:
-            return res
-        # if s > MAX_L or s>= len(chars):
         if s >= len(chars):
             tkslist.append(preTks)
-            return res
+            _memo[state_key] = s
+            return s
+        if s < len(chars) - 4:
+            is_repetitive = True
+            char_to_check = chars[s]
+            for i in range(1, 5):
+                if s + i >= len(chars) or chars[s + i] != char_to_check:
+                    is_repetitive = False
+                    break
+            if is_repetitive:
+                end = s
+                while end < len(chars) and chars[end] == char_to_check:
+                    end += 1
+                mid = s + min(10, end - s)
+                t = "".join(chars[s:mid])
+                k = self.key_(t)
+                copy_pretks = copy.deepcopy(preTks)
+                if k in self.trie_:
+                    copy_pretks.append((t, self.trie_[k]))
+                else:
+                    copy_pretks.append((t, (-12, '')))
+                next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
+                res = max(res, next_res)
+                _memo[state_key] = res
+                return res

-        # pruning
         S = s + 1
         if s + 2 <= len(chars):
-            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
-            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
-                    self.key_(t2)):
+            t1 = "".join(chars[s:s + 1])
+            t2 = "".join(chars[s:s + 2])
+            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
                 S = s + 2
-        if len(preTks) > 2 and len(
-                preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
+        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
             t1 = preTks[-1][0] + "".join(chars[s:s + 1])
             if self.trie_.has_keys_with_prefix(self.key_(t1)):
                 S = s + 2

-        ################
         for e in range(S, len(chars) + 1):
             t = "".join(chars[s:e])
             k = self.key_(t)

             if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                 break

             if k in self.trie_:
                 pretks = copy.deepcopy(preTks)
-                if k in self.trie_:
-                    pretks.append((t, self.trie_[k]))
-                else:
-                    pretks.append((t, (-12, '')))
-                res = max(res, self.dfs_(chars, e, pretks, tkslist))
+                pretks.append((t, self.trie_[k]))
+                res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))

         if res > s:
+            _memo[state_key] = res
             return res

         t = "".join(chars[s:s + 1])
         k = self.key_(t)
+        copy_pretks = copy.deepcopy(preTks)
         if k in self.trie_:
-            preTks.append((t, self.trie_[k]))
+            copy_pretks.append((t, self.trie_[k]))
         else:
-            preTks.append((t, (-12, '')))
-
-        return self.dfs_(chars, s + 1, preTks, tkslist)
+            copy_pretks.append((t, (-12, '')))
+        result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
+        _memo[state_key] = result
+        return result

     def freq(self, tk):
         k = self.key_(tk)
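The repetitive-sequence branch added above is the part that neutralizes the inputs reported in #6085. As a standalone illustration of the same idea (plain Python, no RAGFlow dependencies, with a hypothetical helper name): detect a run of at least five identical characters at position `s`, emit one token covering up to ten of them, and recurse past the run instead of branching character by character.

```python
def collapse_repetitive_run(chars, s, max_token_len=10):
    """Illustrative sketch of the run handling added to dfs_() (hypothetical helper).

    Returns the end index of a single token that swallows up to `max_token_len`
    repeated characters when chars[s:s+5] are identical; otherwise returns None
    so the normal trie-driven DFS branching can proceed.
    """
    if s >= len(chars) - 4:
        return None
    ch = chars[s]
    if any(chars[s + i] != ch for i in range(1, 5)):
        return None
    end = s
    while end < len(chars) and chars[end] == ch:
        end += 1
    return s + min(max_token_len, end - s)


chars = list("一" * 30 + "十二")
print(collapse_repetitive_run(chars, 0))   # 10: one token over ten identical characters
print(collapse_repetitive_run(chars, 28))  # None: near the tail, fall back to normal DFS
```

In the patched dfs_(), the memo key also folds in the token prefix (`tuple(tk[0] for tk in preTks)`), so cached results are only reused on identical paths; combined with the depth cap of 10, the worst case degrades to appending the remaining text as a single raw token rather than recursing indefinitely.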