diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py
index 6f2686851..c3394971e 100644
--- a/rag/nlp/rag_tokenizer.py
+++ b/rag/nlp/rag_tokenizer.py
@@ -116,55 +116,86 @@ class RagTokenizer:
     def _tradi2simp(self, line):
         return HanziConv.toSimplified(line)
 
-    def dfs_(self, chars, s, preTks, tkslist):
+    def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
+        if _memo is None:
+            _memo = {}
+        MAX_DEPTH = 10
+        if _depth > MAX_DEPTH:
+            if s < len(chars):
+                copy_pretks = copy.deepcopy(preTks)
+                remaining = "".join(chars[s:])
+                copy_pretks.append((remaining, (-12, '')))
+                tkslist.append(copy_pretks)
+            return s
+
+        state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
+        if state_key in _memo:
+            return _memo[state_key]
+
         res = s
-        if len(tkslist) >= 2048:
-            return res
-        # if s > MAX_L or s>= len(chars):
         if s >= len(chars):
             tkslist.append(preTks)
-            return res
-
-        # pruning
+            _memo[state_key] = s
+            return s
+        if s < len(chars) - 4:
+            is_repetitive = True
+            char_to_check = chars[s]
+            for i in range(1, 5):
+                if s + i >= len(chars) or chars[s + i] != char_to_check:
+                    is_repetitive = False
+                    break
+            if is_repetitive:
+                end = s
+                while end < len(chars) and chars[end] == char_to_check:
+                    end += 1
+                mid = s + min(10, end - s)
+                t = "".join(chars[s:mid])
+                k = self.key_(t)
+                copy_pretks = copy.deepcopy(preTks)
+                if k in self.trie_:
+                    copy_pretks.append((t, self.trie_[k]))
+                else:
+                    copy_pretks.append((t, (-12, '')))
+                next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
+                res = max(res, next_res)
+                _memo[state_key] = res
+                return res
+
         S = s + 1
         if s + 2 <= len(chars):
-            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
-            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
-                    self.key_(t2)):
+            t1 = "".join(chars[s:s + 1])
+            t2 = "".join(chars[s:s + 2])
+            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
                 S = s + 2
-        if len(preTks) > 2 and len(
-                preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
+        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
             t1 = preTks[-1][0] + "".join(chars[s:s + 1])
             if self.trie_.has_keys_with_prefix(self.key_(t1)):
                 S = s + 2
-
-        ################
+
         for e in range(S, len(chars) + 1):
             t = "".join(chars[s:e])
             k = self.key_(t)
-
             if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                 break
-
             if k in self.trie_:
                 pretks = copy.deepcopy(preTks)
-                if k in self.trie_:
-                    pretks.append((t, self.trie_[k]))
-                else:
-                    pretks.append((t, (-12, '')))
-                res = max(res, self.dfs_(chars, e, pretks, tkslist))
-
+                pretks.append((t, self.trie_[k]))
+                res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
+
         if res > s:
+            _memo[state_key] = res
             return res
-
+
         t = "".join(chars[s:s + 1])
         k = self.key_(t)
+        copy_pretks = copy.deepcopy(preTks)
         if k in self.trie_:
-            preTks.append((t, self.trie_[k]))
+            copy_pretks.append((t, self.trie_[k]))
         else:
-            preTks.append((t, (-12, '')))
-
-        return self.dfs_(chars, s + 1, preTks, tkslist)
+            copy_pretks.append((t, (-12, '')))
+        result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
+        _memo[state_key] = result
+        return result
 
     def freq(self, tk):
         k = self.key_(tk)
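
Below is a minimal, self-contained sketch (not RAGFlow code; the vocabulary, the scores, and names such as `segment_dfs` and `VOCAB` are illustrative assumptions) of the two safeguards this patch adds to `dfs_`: a recursion-depth cap that swallows the remaining characters as a single low-score token, and memoization keyed on the (position, token-prefix) state so the same search state is never expanded twice.

```python
import copy

VOCAB = {"ra": 1, "rag": 2, "flow": 2, "ragflow": 3}  # toy dictionary with word scores
MAX_DEPTH = 10


def segment_dfs(chars, s, pre_tokens, results, depth=0, memo=None):
    """Return the furthest position reached; collect candidate segmentations in `results`."""
    if memo is None:
        memo = {}

    # Depth cap: past MAX_DEPTH, emit the rest as one unknown token (score -12)
    # instead of exploring further, mirroring the patch's fallback.
    if depth > MAX_DEPTH:
        if s < len(chars):
            tks = copy.deepcopy(pre_tokens)
            tks.append(("".join(chars[s:]), -12))
            results.append(tks)
        return s

    # Memoization: a (position, token-prefix) state is only expanded once.
    state_key = (s, tuple(t for t, _ in pre_tokens))
    if state_key in memo:
        return memo[state_key]

    if s >= len(chars):
        results.append(pre_tokens)
        memo[state_key] = s
        return s

    best = s
    for e in range(s + 1, len(chars) + 1):
        piece = "".join(chars[s:e])
        if piece in VOCAB:                      # only branch on known words
            tks = copy.deepcopy(pre_tokens)
            tks.append((piece, VOCAB[piece]))
            best = max(best, segment_dfs(chars, e, tks, results, depth + 1, memo))

    if best > s:
        memo[state_key] = best
        return best

    # No dictionary word starts here: emit one unknown character and move on.
    tks = copy.deepcopy(pre_tokens)
    tks.append((chars[s], -12))
    best = segment_dfs(chars, s + 1, tks, results, depth + 1, memo)
    memo[state_key] = best
    return best


if __name__ == "__main__":
    out = []
    segment_dfs(list("ragflow"), 0, [], out)
    for cand in out:
        print(cand)
```

The depth cap trades exhaustiveness of candidate segmentations for bounded work, which appears to be the patch's intent: repetitive or pathological input no longer drives the DFS into exponential expansion, while common short inputs are still segmented against the full dictionary.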