88 lines
3.5 KiB
Python
88 lines
3.5 KiB
Python
from flask import Flask
|
|
from flask_restplus import Resource, Api
|
|
from datetime import datetime
|
|
import dill
|
|
import torch
|
|
from torchtext.data import BucketIterator
|
|
from model import Seq2SeqConcat, Encoder, Decoder
|
|
from dataset import Seq2SeqDataset, PAD
|
|
from alphabet import LEGAL, PUNCT, A_LIST, O_LIST, U_LIST, E_LIST, D_LIST, I_LIST, Y_LIST
|
|
from unidecode import unidecode
|
|
|
|
# Flask application object and the REST API wrapper registered on top of it.
app = Flask(__name__)
api = Api(app)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
class Minh:
    """Predictor that wraps a ``Seq2SeqConcat`` model and its vocabularies.

    The vocabularies and the checkpoint are loaded once in ``__init__``;
    ``predict`` then maps a single input string to the decoded output string.
    """

    def __init__(self, src_vocab_path, tgt_vocab_path, model_path,
                 max_len=1000, hidden_size=300, n_layers=2):
        """Load both vocabularies, build the network and restore its weights.

        Args:
            src_vocab_path: Path to the dill-serialized source vocabulary.
            tgt_vocab_path: Path to the dill-serialized target vocabulary.
            model_path: Path to a checkpoint containing a
                ``'model_state_dict'`` entry.
            max_len: Forwarded to ``Encoder`` and both ``Decoder`` instances.
            hidden_size: Passed to ``Encoder``; the decoders receive twice
                this value.
            n_layers: Forwarded to ``Encoder`` and both ``Decoder`` instances.
        """
        # NOTE(review): dill.load can execute arbitrary code embedded in the
        # file, so these paths must only ever point at trusted artefacts.
        with open(src_vocab_path, 'rb') as f:
            self.src = dill.load(f)
        with open(tgt_vocab_path, 'rb') as f:
            self.tgt = dill.load(f)

        encoder = Encoder(self.src, max_len, hidden_size, n_layers)
        decoder = Decoder(self.tgt, max_len, hidden_size * 2, n_layers)
        reverse_decoder = Decoder(self.tgt, max_len, hidden_size * 2, n_layers,
                                  reverse=True)
        self.model = Seq2SeqConcat(encoder, decoder, reverse_decoder,
                                   self.src.stoi[PAD])

        # Map the stored tensors straight onto the device selected at module
        # level. This replaces the separate CUDA/CPU branches and also covers
        # checkpoints saved on a GPU index that does not exist on this host.
        checkpoint = torch.load(model_path, map_location=device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

    def predict(self, text):
        """Run the model on one string and return the decoded prediction.

        Args:
            text: Raw input string; it is normalized with ``preprocess``
                before being fed to the model.

        Returns:
            The predicted characters joined into a string, with the first and
            last predicted tokens dropped (presumably start/end markers --
            confirm against the decoder). ``None`` if the iterator yields no
            batch.
        """
        test = Seq2SeqDataset.from_list([self.preprocess(text)])
        test.src_field.vocab = self.src

        test_iterator = BucketIterator(dataset=test, batch_size=1, train=False,
                                       sort=False, sort_within_batch=False,
                                       shuffle=False, device=device)

        with torch.no_grad():
            # The dataset holds a single example, so only the first batch is
            # ever consumed; the method returns from inside the loop.
            for batch in test_iterator:
                _, _, output = self.model(batch, has_targets=False,
                                          mask_softmax=1.0, teacher_forcing=1.0)
                # Keep only the indices of the highest-scoring entry along the
                # last dimension; the scores themselves are not needed.
                _, predicted_indices = torch.max(output, dim=-1)
                predicted_seq = [self.tgt.itos[c]
                                 for c in predicted_indices.squeeze(0).tolist()]
                return ''.join(predicted_seq[1:-1])

        # No batch was produced.
        return None

    def preprocess(self, line):
        """Normalize a raw line into the character set the model expects.

        The line is stripped and lower-cased, transliterated with
        ``unidecode``, every character found in ``PUNCT`` is replaced by
        ``'-'``, and finally every character not found in ``LEGAL`` is
        replaced by ``'?'``.

        Args:
            line: Raw input string.

        Returns:
            The normalized string.
        """
        line = line.strip().lower()
        line = unidecode(line)
        # Replace all punctuation with '-'.
        line = ''.join(c if c not in PUNCT else '-' for c in line)
        # Replace unknown characters with '?'.
        line = ''.join(c if c in LEGAL else '?' for c in line)
        return line
|
|
|
|
# Locations of the trained artefacts, relative to the working directory.
src_vocab_path = "checkpoint/vocab.src"
tgt_vocab_path = "checkpoint/vocab.tgt"
model_path = "checkpoint/aivivn_tone.model.ep25"
# Not referenced anywhere else in the visible code.
wlm_path = "lm/corpus-wplm-4g-v2.binary"

# Build the shared predictor once, at import time, so that requests reuse it.
predictor = Minh(
    src_vocab_path=src_vocab_path,
    tgt_vocab_path=tgt_vocab_path,
    model_path=model_path,
)

# Run a single prediction and discard the result
# (presumably a warm-up / smoke test -- confirm).
predictor.predict('nghieng nuoc nghieng thanh')
|
|
@api.route('/spell/<string:text>')
class Spell(Resource):
    """REST resource that runs the module-level predictor on the URL text."""

    def get(self, text):
        """Predict for ``text`` and report how long the prediction took.

        Args:
            text: The string captured from the ``/spell/<text>`` URL segment.

        Returns:
            A dict with two keys: ``'duration'``, the elapsed time of the
            prediction as the string form of a ``timedelta``, and
            ``'result'``, the value returned by ``predictor.predict``.
        """
        start = datetime.now()
        result = predictor.predict(text)
        duration = datetime.now() - start

        dic = {'duration': str(duration), 'result': result}
        # Echo every response to stdout, as the original handler did.
        print(dic)
        return dic
|
|
if __name__ == '__main__':
    # Start Flask's built-in server on port 5000 when run as a script.
    # NOTE(review): debug=True turns on Flask's debug mode, which is meant
    # for development only -- confirm this entry point is not used in
    # production.
    app.run(port=5000, debug=True)