diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index ae932be85..47a044669 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -19,7 +19,6 @@ import random
 from collections import Counter
 
 from common.token_utils import num_tokens_from_string
-from . import rag_tokenizer
 import re
 import copy
 import roman_numbers as r
@@ -29,6 +28,8 @@ from PIL import Image
 
 import chardet
 
+__all__ = ['rag_tokenizer']
+
 all_codecs = [
     'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
     'cp037', 'cp273', 'cp424', 'cp437',
diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py
index 234566f7f..2ad0bfd3f 100644
--- a/rag/nlp/rag_tokenizer.py
+++ b/rag/nlp/rag_tokenizer.py
@@ -14,455 +14,23 @@
 # limitations under the License.
 #
 
-import logging
-import copy
-import datrie
-import math
-import os
-import re
-import string
-import sys
-from hanziconv import HanziConv
-from nltk import word_tokenize
-from nltk.stem import PorterStemmer, WordNetLemmatizer
-from common.file_utils import get_project_base_directory
+from infinity import rag_tokenizer
 from common import settings
 
 
-class RagTokenizer:
-    def key_(self, line):
-        return str(line.lower().encode("utf-8"))[2:-1]
-
-    def rkey_(self, line):
-        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
-
-    def _load_dict(self, fnm):
-        logging.info(f"[HUQIE]:Build trie from {fnm}")
-        try:
-            of = open(fnm, "r", encoding="utf-8")
-            while True:
-                line = of.readline()
-                if not line:
-                    break
-                line = re.sub(r"[\r\n]+", "", line)
-                line = re.split(r"[ \t]", line)
-                k = self.key_(line[0])
-                F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
-                if k not in self.trie_ or self.trie_[k][0] < F:
-                    self.trie_[self.key_(line[0])] = (F, line[2])
-                self.trie_[self.rkey_(line[0])] = 1
-
-            dict_file_cache = fnm + ".trie"
-            logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
-            self.trie_.save(dict_file_cache)
-            of.close()
-        except Exception:
-            logging.exception(f"[HUQIE]:Build trie {fnm} failed")
-
-    def __init__(self, debug=False):
-        self.DEBUG = debug
-        self.DENOMINATOR = 1000000
-        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
-
-        self.stemmer = PorterStemmer()
-        self.lemmatizer = WordNetLemmatizer()
-
-        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
-
-        trie_file_name = self.DIR_ + ".txt.trie"
-        # check if trie file existence
-        if os.path.exists(trie_file_name):
-            try:
-                # load trie from file
-                self.trie_ = datrie.Trie.load(trie_file_name)
-                return
-            except Exception:
-                # fail to load trie from file, build default trie
-                logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
-                self.trie_ = datrie.Trie(string.printable)
-        else:
-            # file not exist, build default trie
-            logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
-            self.trie_ = datrie.Trie(string.printable)
-
-        # load data from dict file and save to trie file
-        self._load_dict(self.DIR_ + ".txt")
-
-    def load_user_dict(self, fnm):
-        try:
-            self.trie_ = datrie.Trie.load(fnm + ".trie")
-            return
-        except Exception:
-            self.trie_ = datrie.Trie(string.printable)
-        self._load_dict(fnm)
-
-    def add_user_dict(self, fnm):
-        self._load_dict(fnm)
-
-    def _strQ2B(self, ustring):
-        """Convert full-width characters to half-width characters"""
-        rstring = ""
-        for uchar in ustring:
-            inside_code = ord(uchar)
-            if inside_code == 0x3000:
-                inside_code = 0x0020
-            else:
-                inside_code -= 0xFEE0
-            if inside_code < 0x0020 or inside_code > 0x7E:  # After the conversion, if it's not a half-width character, return the original character.
-                rstring += uchar
-            else:
-                rstring += chr(inside_code)
-        return rstring
-
-    def _tradi2simp(self, line):
-        return HanziConv.toSimplified(line)
-
-    def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
-        if _memo is None:
-            _memo = {}
-        MAX_DEPTH = 10
-        if _depth > MAX_DEPTH:
-            if s < len(chars):
-                copy_pretks = copy.deepcopy(preTks)
-                remaining = "".join(chars[s:])
-                copy_pretks.append((remaining, (-12, "")))
-                tkslist.append(copy_pretks)
-            return s
-
-        state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
-        if state_key in _memo:
-            return _memo[state_key]
-
-        res = s
-        if s >= len(chars):
-            tkslist.append(preTks)
-            _memo[state_key] = s
-            return s
-        if s < len(chars) - 4:
-            is_repetitive = True
-            char_to_check = chars[s]
-            for i in range(1, 5):
-                if s + i >= len(chars) or chars[s + i] != char_to_check:
-                    is_repetitive = False
-                    break
-            if is_repetitive:
-                end = s
-                while end < len(chars) and chars[end] == char_to_check:
-                    end += 1
-                mid = s + min(10, end - s)
-                t = "".join(chars[s:mid])
-                k = self.key_(t)
-                copy_pretks = copy.deepcopy(preTks)
-                if k in self.trie_:
-                    copy_pretks.append((t, self.trie_[k]))
-                else:
-                    copy_pretks.append((t, (-12, "")))
-                next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
-                res = max(res, next_res)
-                _memo[state_key] = res
-                return res
-
-        S = s + 1
-        if s + 2 <= len(chars):
-            t1 = "".join(chars[s : s + 1])
-            t2 = "".join(chars[s : s + 2])
-            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
-                S = s + 2
-        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
-            t1 = preTks[-1][0] + "".join(chars[s : s + 1])
-            if self.trie_.has_keys_with_prefix(self.key_(t1)):
-                S = s + 2
-
-        for e in range(S, len(chars) + 1):
-            t = "".join(chars[s:e])
-            k = self.key_(t)
-            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
-                break
-            if k in self.trie_:
-                pretks = copy.deepcopy(preTks)
-                pretks.append((t, self.trie_[k]))
-                res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
-
-        if res > s:
-            _memo[state_key] = res
-            return res
-
-        t = "".join(chars[s : s + 1])
-        k = self.key_(t)
-        copy_pretks = copy.deepcopy(preTks)
-        if k in self.trie_:
-            copy_pretks.append((t, self.trie_[k]))
-        else:
-            copy_pretks.append((t, (-12, "")))
-        result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
-        _memo[state_key] = result
-        return result
-
-    def freq(self, tk):
-        k = self.key_(tk)
-        if k not in self.trie_:
-            return 0
-        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
-
-    def tag(self, tk):
-        k = self.key_(tk)
-        if k not in self.trie_:
-            return ""
-        return self.trie_[k][1]
-
-    def score_(self, tfts):
-        B = 30
-        F, L, tks = 0, 0, []
-        for tk, (freq, tag) in tfts:
-            F += freq
-            L += 0 if len(tk) < 2 else 1
-            tks.append(tk)
-        # F /= len(tks)
-        L /= len(tks)
-        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
-        return tks, B / len(tks) + L + F
-
-    def _sort_tokens(self, tkslist):
-        res = []
-        for tfts in tkslist:
-            tks, s = self.score_(tfts)
-            res.append((tks, s))
-        return sorted(res, key=lambda x: x[1], reverse=True)
-
-    def merge_(self, tks):
-        # if split chars is part of token
-        res = []
-        tks = re.sub(r"[ ]+", " ", tks).split()
-        s = 0
-        while True:
-            if s >= len(tks):
-                break
-            E = s + 1
-            for e in range(s + 2, min(len(tks) + 2, s + 6)):
-                tk = "".join(tks[s:e])
-                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
-                    E = e
-            res.append("".join(tks[s:E]))
-            s = E
-
-        return " ".join(res)
-
-    def _max_forward(self, line):
-        res = []
-        s = 0
-        while s < len(line):
-            e = s + 1
-            t = line[s:e]
-            while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
-                e += 1
-                t = line[s:e]
-
-            while e - 1 > s and self.key_(t) not in self.trie_:
-                e -= 1
-                t = line[s:e]
-
-            if self.key_(t) in self.trie_:
-                res.append((t, self.trie_[self.key_(t)]))
-            else:
-                res.append((t, (0, "")))
-
-            s = e
-
-        return self.score_(res)
-
-    def _max_backward(self, line):
-        res = []
-        s = len(line) - 1
-        while s >= 0:
-            e = s + 1
-            t = line[s:e]
-            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
-                s -= 1
-                t = line[s:e]
-
-            while s + 1 < e and self.key_(t) not in self.trie_:
-                s += 1
-                t = line[s:e]
-
-            if self.key_(t) in self.trie_:
-                res.append((t, self.trie_[self.key_(t)]))
-            else:
-                res.append((t, (0, "")))
-
-            s -= 1
-
-        return self.score_(res[::-1])
-
-    def english_normalize_(self, tks):
-        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
-
-    def _split_by_lang(self, line):
-        txt_lang_pairs = []
-        arr = re.split(self.SPLIT_CHAR, line)
-        for a in arr:
-            if not a:
-                continue
-            s = 0
-            e = s + 1
-            zh = is_chinese(a[s])
-            while e < len(a):
-                _zh = is_chinese(a[e])
-                if _zh == zh:
-                    e += 1
-                    continue
-                txt_lang_pairs.append((a[s:e], zh))
-                s = e
-                e = s + 1
-                zh = _zh
-            if s >= len(a):
-                continue
-            txt_lang_pairs.append((a[s:e], zh))
-        return txt_lang_pairs
 
+class RagTokenizer(rag_tokenizer.RagTokenizer):
     def tokenize(self, line: str) -> str:
         if settings.DOC_ENGINE_INFINITY:
             return line
-        line = re.sub(r"\W+", " ", line)
-        line = self._strQ2B(line).lower()
-        line = self._tradi2simp(line)
-
-        arr = self._split_by_lang(line)
-        res = []
-        for L, lang in arr:
-            if not lang:
-                res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
-                continue
-            if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
-                res.append(L)
-                continue
-
-            # use maxforward for the first time
-            tks, s = self._max_forward(L)
-            tks1, s1 = self._max_backward(L)
-            if self.DEBUG:
-                logging.debug("[FW] {} {}".format(tks, s))
-                logging.debug("[BW] {} {}".format(tks1, s1))
-
-            i, j, _i, _j = 0, 0, 0, 0
-            same = 0
-            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
-                same += 1
-            if same > 0:
-                res.append(" ".join(tks[j : j + same]))
-            _i = i + same
-            _j = j + same
-            j = _j + 1
-            i = _i + 1
-
-            while i < len(tks1) and j < len(tks):
-                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
-                if tk1 != tk:
-                    if len(tk1) > len(tk):
-                        j += 1
-                    else:
-                        i += 1
-                    continue
-
-                if tks1[i] != tks[j]:
-                    i += 1
-                    j += 1
-                    continue
-                # backward tokens from _i to i are different from forward tokens from _j to j.
-                tkslist = []
-                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
-                res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
-
-                same = 1
-                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
-                    same += 1
-                res.append(" ".join(tks[j : j + same]))
-                _i = i + same
-                _j = j + same
-                j = _j + 1
-                i = _i + 1
-
-            if _i < len(tks1):
-                assert _j < len(tks)
-                assert "".join(tks1[_i:]) == "".join(tks[_j:])
-                tkslist = []
-                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
-                res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
-
-        res = " ".join(res)
-        logging.debug("[TKS] {}".format(self.merge_(res)))
-        return self.merge_(res)
+        else:
+            return super().tokenize(line)
 
     def fine_grained_tokenize(self, tks: str) -> str:
         if settings.DOC_ENGINE_INFINITY:
             return tks
-        tks = tks.split()
-        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
-        if zh_num < len(tks) * 0.2:
-            res = []
-            for tk in tks:
-                res.extend(tk.split("/"))
-            return " ".join(res)
-
-        res = []
-        for tk in tks:
-            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
-                res.append(tk)
-                continue
-            tkslist = []
-            if len(tk) > 10:
-                tkslist.append(tk)
-            else:
-                self.dfs_(tk, 0, [], tkslist)
-            if len(tkslist) < 2:
-                res.append(tk)
-                continue
-            stk = self._sort_tokens(tkslist)[1][0]
-            if len(stk) == len(tk):
-                stk = tk
-            else:
-                if re.match(r"[a-z\.-]+$", tk):
-                    for t in stk:
-                        if len(t) < 3:
-                            stk = tk
-                            break
-                    else:
-                        stk = " ".join(stk)
-                else:
-                    stk = " ".join(stk)
-
-            res.append(stk)
-
-        return " ".join(self.english_normalize_(res))
-
-
-def is_chinese(s):
-    if s >= "\u4e00" and s <= "\u9fa5":
-        return True
-    else:
-        return False
-
-
-def is_number(s):
-    if s >= "\u0030" and s <= "\u0039":
-        return True
-    else:
-        return False
-
-
-def is_alphabet(s):
-    if ("\u0041" <= s <= "\u005a") or ("\u0061" <= s <= "\u007a"):
-        return True
-    else:
-        return False
-
-
-def naive_qie(txt):
-    tks = []
-    for t in txt.split():
-        if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
-            tks.append(" ")
-        tks.append(t)
-    return tks
+        else:
+            return super().fine_grained_tokenize(tks)
 
 
 tokenizer = RagTokenizer()
@@ -470,40 +38,5 @@ tokenize = tokenizer.tokenize
 fine_grained_tokenize = tokenizer.fine_grained_tokenize
 tag = tokenizer.tag
 freq = tokenizer.freq
-load_user_dict = tokenizer.load_user_dict
-add_user_dict = tokenizer.add_user_dict
 tradi2simp = tokenizer._tradi2simp
 strQ2B = tokenizer._strQ2B
-
-if __name__ == "__main__":
-    tknzr = RagTokenizer(debug=True)
-    # huqie.add_user_dict("/tmp/tmp.new.tks.dict")
-    texts = [
-        "over_the_past.pdf",
-        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈",
-        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。",
-        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥",
-        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa",
-        "虽然我不怎么玩",
-        "蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的",
-        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了",
-        "这周日你去吗?这周日你有空吗?",
-        "Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ",
-        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-",
-    ]
-    for text in texts:
-        print(text)
-        tks1 = tknzr.tokenize(text)
-        tks2 = tknzr.fine_grained_tokenize(tks1)
-        print(tks1)
-        print(tks2)
-    if len(sys.argv) < 2:
-        sys.exit()
-    tknzr.load_user_dict(sys.argv[1])
-    of = open(sys.argv[2], "r")
-    while True:
-        line = of.readline()
-        if not line:
-            break
-        print(tknzr.tokenize(line))
-    of.close()