diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py
index 0b78c4e09..c262ca9b4 100644
--- a/rag/nlp/rag_tokenizer.py
+++ b/rag/nlp/rag_tokenizer.py
@@ -66,7 +66,7 @@ class RagTokenizer:
         self.stemmer = PorterStemmer()
         self.lemmatizer = WordNetLemmatizer()
 
-        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
+        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
 
         trie_file_name = self.DIR_ + ".txt.trie"
         # check if trie file existence
@@ -263,22 +263,44 @@ class RagTokenizer:
     def english_normalize_(self, tks):
         return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
 
+    def _split_by_lang(self, line):
+        txt_lang_pairs = []
+        arr = re.split(self.SPLIT_CHAR, line)
+        for a in arr:
+            if not a:
+                continue
+            s = 0
+            e = s + 1
+            zh = is_chinese(a[s])
+            while e < len(a):
+                _zh = is_chinese(a[e])
+                if _zh == zh:
+                    e += 1
+                    continue
+                txt_lang_pairs.append((a[s: e], zh))
+                s = e
+                e = s + 1
+                zh = _zh
+            if s >= len(a):
+                continue
+            txt_lang_pairs.append((a[s: e], zh))
+        return txt_lang_pairs
+
     def tokenize(self, line):
         line = re.sub(r"\W+", " ", line)
         line = self._strQ2B(line).lower()
         line = self._tradi2simp(line)
-        zh_num = len([1 for c in line if is_chinese(c)])
-        if zh_num == 0:
-            return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
 
-        arr = re.split(self.SPLIT_CHAR, line)
+        arr = self._split_by_lang(line)
         res = []
-        for L in arr:
+        for L, lang in arr:
+            if not lang:
+                res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
+                continue
             if len(L) < 2 or re.match(
                     r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                 res.append(L)
                 continue
-            # print(L)
 
             # use maxforward for the first time
             tks, s = self.maxForward_(L)
@@ -332,7 +354,7 @@ class RagTokenizer:
             self.dfs_("".join(tks[_j:]), 0, [], tkslist)
             res.append(" ".join(self.sortTks_(tkslist)[0][0]))
 
-        res = " ".join(self.english_normalize_(res))
+        res = " ".join(res)
 
         logging.debug("[TKS] {}".format(self.merge_(res)))
         return self.merge_(res)
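
For reviewers, here is a minimal standalone sketch of the run-splitting scan that the new `_split_by_lang()` performs on each `re.split(self.SPLIT_CHAR, ...)` fragment. The `is_chinese()` helper is approximated with the basic CJK Unified Ideographs block, which may differ from the module's actual helper; this illustrates the technique rather than reproducing the module's exact code path.

```python
import re

# Assumed helper mirroring is_chinese() from the module: true for characters
# in the basic CJK Unified Ideographs block. The real helper may cover more ranges.
def is_chinese(ch):
    return "\u4e00" <= ch <= "\u9fff"

def split_by_lang(chunk):
    """Split one delimiter-free chunk into (text, is_chinese) runs,
    the same left-to-right scan _split_by_lang() performs per fragment."""
    if not chunk:  # the real code skips empty fragments before scanning
        return []
    pairs = []
    s, zh = 0, is_chinese(chunk[0])
    for e in range(1, len(chunk)):
        _zh = is_chinese(chunk[e])
        if _zh != zh:
            pairs.append((chunk[s:e], zh))  # close the current run at the script boundary
            s, zh = e, _zh
    pairs.append((chunk[s:], zh))  # flush the trailing run
    return pairs

print(split_by_lang("虚拟dom机制"))
# -> [('虚拟', True), ('dom', False), ('机制', True)]
```

As the diff shows, a line containing any Chinese character previously went through the dictionary-based path as a whole, and embedded English tokens were only caught afterwards by `english_normalize_()`. With per-run splitting, each Latin/digit run is stemmed and lemmatized inline via `word_tokenize`, and each Chinese run goes through `maxForward_`/`dfs_`, so the trailing `english_normalize_(res)` pass can be replaced by a plain join.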