diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py
index eee013e61..bfe4887a6 100644
--- a/rag/nlp/rag_tokenizer.py
+++ b/rag/nlp/rag_tokenizer.py
@@ -281,34 +281,49 @@ class RagTokenizer:
             print("[FW]", tks, s)
             print("[BW]", tks1, s1)
 
-        diff = [0 for _ in range(max(len(tks1), len(tks)))]
-        for i in range(min(len(tks1), len(tks))):
-            if tks[i] != tks1[i]:
-                diff[i] = 1
+        i, j, _i, _j = 0, 0, 0, 0
+        same = 0
+        while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
+            same += 1
+        if same > 0: res.append(" ".join(tks[j: j + same]))
+        _i = i + same
+        _j = j + same
+        j = _j + 1
+        i = _i + 1
 
-        if s1 > s:
-            tks = tks1
-
-        i = 0
-        while i < len(tks):
-            s = i
-            while s < len(tks) and diff[s] == 0:
-                s += 1
-            if s == len(tks):
-                res.append(" ".join(tks[i:]))
-                break
-            if s > i:
-                res.append(" ".join(tks[i:s]))
-
-            e = s
-            while e < len(tks) and e - s < 5 and diff[e] == 1:
-                e += 1
+        while i < len(tks1) and j < len(tks):
+            tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
+            if tk1 != tk:
+                if len(tk1) > len(tk):
+                    j += 1
+                else:
+                    i += 1
+                continue
+            if tks1[i] != tks[j]:
+                i += 1
+                j += 1
+                continue
+            # backward tokens from _i to i are different from forward tokens from _j to j.
             tkslist = []
-            self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
+            self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
             res.append(" ".join(self.sortTks_(tkslist)[0][0]))
-            i = e + 1
+            same = 1
+            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
+                same += 1
+            res.append(" ".join(tks[j: j + same]))
+            _i = i + same
+            _j = j + same
+            j = _j + 1
+            i = _i + 1
+
+        if _i < len(tks1):
+            assert _j < len(tks)
+            assert "".join(tks1[_i:]) == "".join(tks[_j:])
+            tkslist = []
+            self.dfs_("".join(tks[_j:]), 0, [], tkslist)
+            res.append(" ".join(self.sortTks_(tkslist)[0][0]))
 
         res = " ".join(self.english_normalize_(res))
         if self.DEBUG:
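
The hunk above replaces the positional token-by-token comparison with a character-offset alignment of the forward (tks) and backward (tks1) tokenizations: runs where both agree are emitted as-is, and only the disputed character spans are re-segmented via dfs_/sortTks_. Below is a minimal, standalone sketch of that alignment idea on hypothetical inputs; align_spans and the sample token lists are illustrative assumptions, not part of rag_tokenizer.py.

# Standalone sketch of the forward/backward alignment idea, assuming both
# token lists cover the same text. Not the actual rag_tokenizer.py code.

def align_spans(fwd, bwd):
    """Yield ("same", tokens) chunks where both tokenizations agree
    token-for-token, and ("diff", substring) chunks where they disagree
    until the underlying characters realign."""
    text = "".join(fwd)
    assert text == "".join(bwd), "both lists must tokenize the same text"
    i = j = 0       # token cursors into fwd / bwd
    ci = cj = 0     # character offsets reached in fwd / bwd
    start = 0       # character offset where the current disputed span began
    out = []
    while i < len(fwd) and j < len(bwd):
        if ci == cj and fwd[i] == bwd[j]:
            # Character-aligned and the next tokens match: flush any pending
            # disputed span, then collect the whole run of identical tokens.
            if start < ci:
                out.append(("diff", text[start:ci]))
            same = []
            while i < len(fwd) and j < len(bwd) and fwd[i] == bwd[j]:
                same.append(fwd[i])
                ci += len(fwd[i]); cj += len(bwd[j])
                i += 1; j += 1
            out.append(("same", same))
            start = ci
        elif ci <= cj:
            ci += len(fwd[i]); i += 1   # fwd is not ahead: consume a forward token
        else:
            cj += len(bwd[j]); j += 1   # bwd is behind: consume a backward token
    if start < len(text):
        out.append(("diff", text[start:]))  # trailing disputed span
    return out


if __name__ == "__main__":
    # Hypothetical forward/backward segmentations of the same string.
    fwd = ["ab", "c", "de"]
    bwd = ["a", "bc", "de"]
    # -> [('diff', 'abc'), ('same', ['de'])]; in the patched method the
    # disputed span would be re-segmented with dfs_/sortTks_, while the
    # agreeing run is appended to res unchanged.
    print(align_spans(fwd, bwd))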