Fix infinite recursion in RagTokenizer when processing repetitive characters (#6109)

### What problem does this PR solve?
fix #6085 
RagTokenizer's dfs_() function falls into infinite recursion when
processing text with repetitive Chinese characters (e.g.,
"一一一一一十一十一十一..." or "一一一一一一十十十十十十十二十二十二..."), causing memory leaks.
### Type of change
Implemented three optimizations to the dfs_() function:
1.Added memoization with _memo dictionary to cache computed results
2.Added recursion depth limiting with _depth parameter (max 10 levels)
3.Implemented special handling for repetitive character sequences
- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
kaiyuan Zhang 2025-04-01 13:59:52 +08:00 committed by GitHub
parent 58e6e7b668
commit ead5f7aba9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -116,55 +116,86 @@ class RagTokenizer:
def _tradi2simp(self, line): def _tradi2simp(self, line):
return HanziConv.toSimplified(line) return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist): def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
if _memo is None:
_memo = {}
MAX_DEPTH = 10
if _depth > MAX_DEPTH:
if s < len(chars):
copy_pretks = copy.deepcopy(preTks)
remaining = "".join(chars[s:])
copy_pretks.append((remaining, (-12, '')))
tkslist.append(copy_pretks)
return s
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
if state_key in _memo:
return _memo[state_key]
res = s res = s
if len(tkslist) >= 2048:
return res
# if s > MAX_L or s>= len(chars):
if s >= len(chars): if s >= len(chars):
tkslist.append(preTks) tkslist.append(preTks)
return res _memo[state_key] = s
return s
if s < len(chars) - 4:
is_repetitive = True
char_to_check = chars[s]
for i in range(1, 5):
if s + i >= len(chars) or chars[s + i] != char_to_check:
is_repetitive = False
break
if is_repetitive:
end = s
while end < len(chars) and chars[end] == char_to_check:
end += 1
mid = s + min(10, end - s)
t = "".join(chars[s:mid])
k = self.key_(t)
copy_pretks = copy.deepcopy(preTks)
if k in self.trie_:
copy_pretks.append((t, self.trie_[k]))
else:
copy_pretks.append((t, (-12, '')))
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
res = max(res, next_res)
_memo[state_key] = res
return res
# pruning
S = s + 1 S = s + 1
if s + 2 <= len(chars): if s + 2 <= len(chars):
t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2]) t1 = "".join(chars[s:s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix( t2 = "".join(chars[s:s + 2])
self.key_(t2)): if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
S = s + 2 S = s + 2
if len(preTks) > 2 and len( if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
t1 = preTks[-1][0] + "".join(chars[s:s + 1]) t1 = preTks[-1][0] + "".join(chars[s:s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)): if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2 S = s + 2
################
for e in range(S, len(chars) + 1): for e in range(S, len(chars) + 1):
t = "".join(chars[s:e]) t = "".join(chars[s:e])
k = self.key_(t) k = self.key_(t)
if e > s + 1 and not self.trie_.has_keys_with_prefix(k): if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
break break
if k in self.trie_: if k in self.trie_:
pretks = copy.deepcopy(preTks) pretks = copy.deepcopy(preTks)
if k in self.trie_: pretks.append((t, self.trie_[k]))
pretks.append((t, self.trie_[k])) res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
else:
pretks.append((t, (-12, '')))
res = max(res, self.dfs_(chars, e, pretks, tkslist))
if res > s: if res > s:
_memo[state_key] = res
return res return res
t = "".join(chars[s:s + 1]) t = "".join(chars[s:s + 1])
k = self.key_(t) k = self.key_(t)
copy_pretks = copy.deepcopy(preTks)
if k in self.trie_: if k in self.trie_:
preTks.append((t, self.trie_[k])) copy_pretks.append((t, self.trie_[k]))
else: else:
preTks.append((t, (-12, ''))) copy_pretks.append((t, (-12, '')))
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
return self.dfs_(chars, s + 1, preTks, tkslist) _memo[state_key] = result
return result
def freq(self, tk): def freq(self, tk):
k = self.key_(tk) k = self.key_(tk)