from nltk.tokenize import word_tokenize
import re
from datetime import datetime
from lm.lm import KenLM
import unidecode
import editdistance

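# NOTE: `lm.lm.KenLM` is a project-local wrapper; only its `score(sentence)`
# method is used below. If the wrapper is unavailable, a minimal sketch over
# the standard `kenlm` pip bindings (an assumption, not this project's code)
# could look like:
#
#   import kenlm
#
#   class KenLM:
#       def __init__(self, path):
#           self.model = kenlm.Model(path)
#
#       def score(self, sentence):
#           return self.model.score(sentence, bos=True, eos=True)
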
# Load the reference list of common Vietnamese syllables.
with open('./checkpoint/common-vietnamese-syllables.txt') as f:
    vocab = f.read()

vocab_no_tone = unidecode.unidecode(vocab)  # tone-stripped copy (currently unused)
vocab_ls = word_tokenize(vocab)
vocab_set = set(vocab_ls)  # exact syllable membership for real-word checks

def preprocessing(text):
    """Lowercase, drop newlines, and trim surrounding whitespace."""
    text = text.lower()
    text = text.replace('\n', '')
    text = text.strip()
    return text

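# Example:
#   >>> preprocessing('  Xin Chào\n')
#   'xin chào'
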
def check_real_word(word):
    """Return True if `word` is a known Vietnamese syllable."""
    # Membership is checked against the tokenized syllable set rather than the
    # raw file contents, so substrings of longer syllables do not pass.
    return word in vocab_set

def similarities_by_edit_distance(word, distance=1):
    """
    Generate candidate corrections for a non-word error.

    :param word: the non-word error
    :param distance: maximum edit distance between `word` and a candidate
    :return: a list of vocabulary syllables within the given edit distance
    """
    return [w for w in vocab_ls if editdistance.eval(w, word) <= distance]

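# Illustrative only (actual candidates depend on the syllable file), e.g.
#   similarities_by_edit_distance('chấh') might return ['chấm', 'chất', ...]
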
class Predict:
    def __init__(self, wlm_path):
        print('Loading language model ...')
        self.wlm = KenLM(wlm_path)

    def beam_lm(self, predicted_seq, k=20):
        """Detect non-word errors in `predicted_seq` and correct them greedily,
        left to right, using language-model scores. `k` is currently unused."""
        # Replace non-word errors with '*'.
        sentence = preprocessing(predicted_seq)
        words = word_tokenize(sentence)
        for index, word in enumerate(words):
            if not check_real_word(word):
                words[index] = '*'
        predicted_seq_copy = ' '.join(words)

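        # Hypothetical walkthrough: for input 'tôi đii học', 'đii' is not in
        # the syllable vocabulary, so the masked sentence is 'tôi * học';
        # beam_lm_ below fills each '*' with the best-scoring candidate.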
        def beam_lm_(predicted_seq, predicted_seq_uncertain):
            uncertainties = [m.span() for m in re.finditer(r'\*+', predicted_seq_uncertain)]
            if len(uncertainties) == 0:
                return predicted_seq_uncertain
            ls_words_predict_seq = word_tokenize(predicted_seq)
            ls_words_uncertain = word_tokenize(predicted_seq_uncertain)
            for i, v in enumerate(ls_words_uncertain):
                if v != '*':
                    continue
                # Generate a list of words similar to the original erroneous word.
                c = ls_words_predict_seq[i]
                c_list = similarities_by_edit_distance(c)
                # Return the text unchanged if there are no similar words.
                if len(c_list) == 0:
                    return predicted_seq_uncertain
                # Left context: everything before the error position.
                left_context = ' '.join(ls_words_uncertain[0:i])
                if i == 0:
                    left_context = ''
                # Right context runs up to the next error; the remainder after
                # that error is reattached below and handled by the recursion.
                remain_context = ''
                right_context = ''
                if i < len(ls_words_uncertain) - 1:
                    ls_words_right_context = []
                    for index, w in enumerate(ls_words_uncertain[i + 1:]):
                        if w == '*':
                            if index < len(ls_words_uncertain) - 1:
                                remain_context = ' '.join(ls_words_uncertain[index + i + 1:])
                            break
                        ls_words_right_context.append(w)
                    right_context = ' '.join(ls_words_right_context)
                # Score each candidate in context with the language model.
                candidates = []
                for ch in c_list:
                    candidate = left_context + ' ' + ch + ' ' + right_context
                    score = self.score(candidate)
                    candidates.append({'candidate': candidate, 'score': score})
                candidates = sorted(candidates, key=lambda item: item['score'], reverse=True)
                best_candidate = candidates[0]['candidate'] + ' ' + remain_context
                return beam_lm_(predicted_seq, best_candidate)

        return beam_lm_(sentence, predicted_seq_copy)

    def score(self, candidate):
        return self.wlm.score(candidate)

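# Minimal usage sketch (the model path and input sentence are placeholders):
#   predict = Predict(wlm_path='path/to/kenlm.binary')
#   print(predict.beam_lm('tôi đii học'))
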
start = datetime.now()

# wlm_path = "/home/minh/projects/aivivn-tone/lm/corpus-wplm-4g-v2.binary"
wlm_path = "/home/minh/projects/lla/kenlm/build/3_gram_kenlm.binary"
# wlm_path = "/home/minh/projects/lla/kenlm/build/2_gram_kenlm.binary"
predict = Predict(wlm_path=wlm_path)

if __name__ == '__main__':
    # Ground-truth sentences, sentences with injected non-word errors, and the
    # index of the erroneous word in each sentence.
    with open('./data/data_test_v2/correction_test.txt', 'r') as f1:
        text_correction = f1.readlines()
    with open('./data/data_test_v2/non_word_error_test_v2.txt') as f2:
        non_word_errors = f2.readlines()
    with open('./data/data_test_v2/index_correction_non_word_error_v2.txt', 'r') as f:
        index_co_nos = f.readlines()

    count_right = 0
    count_wrong = 0
    for index, text in enumerate(non_word_errors):
        if count_right + count_wrong == 200000:
            break
        print(index, count_right, count_wrong)
        predicted_seq = preprocessing(predict.beam_lm(predicted_seq=text))
        text_correction_sentence = preprocessing(text_correction[index])
        index_co_no = index_co_nos[index]
        index_no = word_tokenize(index_co_no)[0]
        print(predicted_seq)
        print(text_correction_sentence)
        print(word_tokenize(predicted_seq)[int(index_no)])
        print(word_tokenize(text_correction_sentence)[int(index_no)])
        # Count the prediction as correct if the word at the error position
        # matches the ground truth.
        if word_tokenize(predicted_seq)[int(index_no)] == word_tokenize(text_correction_sentence)[int(index_no)]:
            count_right += 1
            continue
        count_wrong += 1

    print('Duration:', datetime.now() - start)