# spell_correction_n_gram/test_2.py

from nltk.tokenize import word_tokenize
import re
from datetime import datetime
from lm.lm import KenLM
import unidecode
import editdistance
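
# NOTE: nltk.word_tokenize requires the NLTK 'punkt' tokenizer data (nltk.download('punkt')).
# lm.lm.KenLM is assumed to be a project-local wrapper around the kenlm bindings that
# exposes a score(sentence) method returning a language-model score.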

# Load the dictionary of common Vietnamese syllables used to detect non-word errors.
with open('./checkpoint/common-vietnamese-syllables.txt') as f:
    dict_text = f.read()
dict_not_tone = unidecode.unidecode(dict_text)
dict_ls = word_tokenize(dict_text)
dict_set = set(dict_ls)


def preprocessing(text):
    text = text.lower()
    text = text.replace('\n', '')
    text = text.strip()
    return text


def check_real_word(word):
    # A word is "real" if it appears in the syllable dictionary.
    return word in dict_set


def similarities_by_edit_distance(word, distance=1):
    """
    Generate a list of candidate words whose edit distance to the non-word error
    is at most `distance`.
    :param word: non-word error
    :param distance: maximum edit distance {1, 2, 3, ...}
    :return: a list of candidates
    """
    ls_temp = []
    for i in dict_ls:
        if editdistance.eval(i, word) <= distance:
            ls_temp.append(i)
    # print(ls_temp)
    return ls_temp
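
# Illustrative example (actual output depends on the contents of common-vietnamese-syllables.txt):
#   similarities_by_edit_distance('hoz', distance=1)
#   could return candidates such as ['ho', 'hoa', 'hoe'] if those syllables are in the dictionary.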


class Predict:
    def __init__(self, wlm_path):
        print('Loading language model ...')
        self.wlm = KenLM(wlm_path)

    def beam_lm(self, predicted_seq, k=20):  # k is currently unused
        # replace non-word errors with '*'
        sentence = preprocessing(predicted_seq)
        words = word_tokenize(sentence)
        for index, word in enumerate(words):
            if not check_real_word(word):
                # print('Non-Word-Error:', word)
                words[index] = '*'
        predicted_seq_copy = ' '.join(words)

        def beam_lm_(predicted_seq, predicted_seq_uncertain):
            # spans of the remaining '*' placeholders
            uncertainties = [m.span() for m in re.finditer(r'\*+', predicted_seq_uncertain)]
            if len(uncertainties) == 0:
                return predicted_seq_uncertain
            ls_words_predict_seq = word_tokenize(predicted_seq)
            ls_words_uncertain = word_tokenize(predicted_seq_uncertain)
            topk_fwd = [predicted_seq_uncertain[0:uncertainties[0][0]]]  # currently unused
            for i, v in enumerate(ls_words_uncertain):
                if v != '*':
                    continue
                # generate a list of words similar to the original (erroneous) word
                c = ls_words_predict_seq[i]
                c_list = similarities_by_edit_distance(c)
                # return the old text if there are no similar words
                if len(c_list) == 0:
                    return predicted_seq_uncertain
                # left context
                left_context = ' '.join(ls_words_uncertain[0:i])
                if i == 0:
                    left_context = ''
                # right context (up to the next '*') and remaining context (from the next '*' on)
                remain_context = ''
                right_context = ''
                if i < len(ls_words_uncertain) - 1:
                    ls_words_right_context = []
                    for index, w in enumerate(ls_words_uncertain[i + 1:]):
                        if w == '*':
                            if index < len(ls_words_uncertain) - 1:
                                remain_context = ' '.join(ls_words_uncertain[index + i + 1:])
                            break
                        ls_words_right_context.append(w)
                    right_context = ' '.join(ls_words_right_context)
                # score the sentence obtained by substituting each similar word
                candidates = []
                for ch in c_list:
                    candidate = left_context + ' ' + ch + ' ' + right_context
                    score = self.score(candidate)
                    candidates.append({'candidate': candidate, 'score': score})
                candidates = sorted(candidates, key=lambda cand: cand['score'], reverse=True)
                best_candidate = candidates[0]['candidate'] + ' ' + remain_context
                # recurse on the partially corrected sentence to fix the next '*'
                return beam_lm_(predicted_seq, best_candidate)

        return beam_lm_(sentence, predicted_seq_copy)

    def score(self, candidate):
        return self.wlm.score(candidate)
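
# Usage sketch (hypothetical input; the path must point to an existing KenLM binary):
#   p = Predict(wlm_path='/path/to/kenlm.binary')
#   p.beam_lm('họ đã mua các hóa chấtz từ trung quốc')
#   # 'chấtz' is not a dictionary syllable, so it becomes '*' and the highest-scoring
#   # edit-distance-1 candidate (e.g. 'chất') is substituted using the language model.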


start = datetime.now()
# wlm_path = "/home/minh/projects/aivivn-tone/lm/corpus-wplm-4g-v2.binary"
wlm_path = "/home/minh/projects/lla/kenlm/build/3_gram_kenlm.binary"
# wlm_path = "/home/minh/projects/lla/kenlm/build/2_gram_kenlm.binary"
predict = Predict(wlm_path=wlm_path)

if __name__ == '__main__':
    # test data: reference sentences, sentences containing non-word errors,
    # and the word index at which prediction and reference are compared
    with open('./data/data_test_v2/correction_test.txt', 'r') as f1:
        text_correction = f1.readlines()
    with open('./data/data_test_v2/non_word_error_test_v2.txt') as f2:
        non_word_errors = f2.readlines()
    with open('./data/data_test_v2/index_correction_non_word_error_v2.txt', 'r') as f:
        index_co_nos = f.readlines()

    count_right = 0
    count_wrong = 0
    for index, text in enumerate(non_word_errors):
        if count_right + count_wrong == 200000:
            break
        print(index, count_right, count_wrong)
        predicted_seq = preprocessing(predict.beam_lm(predicted_seq=text))
        text_correction_sentence = preprocessing(text_correction[index])
        index_co_no = index_co_nos[index]
        index_no = word_tokenize(index_co_no)[0]
        print(predicted_seq)
        print(text_correction_sentence)
        print(word_tokenize(predicted_seq)[int(index_no)])
        print(word_tokenize(text_correction_sentence)[int(index_no)])
        # correct if the predicted word at the error position matches the reference
        if word_tokenize(predicted_seq)[int(index_no)] == word_tokenize(text_correction_sentence)[int(index_no)]:
            count_right += 1
            continue
        count_wrong += 1

    print('Duration:', datetime.now() - start)
# s1 = """để sản xuất nước hoa giả các thương hiệu nổi tiếng đối tượng tú anh đã mua các hợp chất có nguồn gốc từ trung quốc về để ủ làm tăng độ thơm rồi dùng các máy móc đóng gói mua vỏ nhãn hiệu của các thương hiệu nổi tiếng để dán lên sau đó đưa ra thị trường tiêu thụ"""
# s2 = """để sản xuất nước hoa giả các thương hiệu nổi tiếng đối tượng tú anh đã mua các hóa chất có nguồn gốc từ trung quốc về để ủ làm tăng độ thơm rồi dùng các máy móc đóng gói mua vỏ nhãn hiệu của các thương hiệu nổi tiếng để dán lên sau đó đưa ra thị trường tiêu thụ"""
# ls1 = [c for c in s1]
# ls2 = [c for c in s2]
# for index, c in enumerate(ls1):
# if c != ls2[index]:
# print(c, ls2[index])
# if s1 == s2:
# print('w')
# print(editdistance.eval(s1, s2))