add_tone_seq2seq/app.py
2019-12-27 17:19:22 +07:00

88 lines
3.5 KiB
Python

from flask import Flask
from flask_restplus import Resource, Api
from datetime import datetime
import dill
import torch
from torchtext.data import BucketIterator
from model import Seq2SeqConcat, Encoder, Decoder
from dataset import Seq2SeqDataset, PAD
from alphabet import LEGAL, PUNCT, A_LIST, O_LIST, U_LIST, E_LIST, D_LIST, I_LIST, Y_LIST
from unidecode import unidecode
# Flask application and the flask_restplus API object that the routes below
# are registered on.
app = Flask(__name__)
api = Api(app)

# Device the model and every input batch are placed on: the GPU when CUDA is
# available, the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Minh:
    """Serve single-string predictions from a trained ``Seq2SeqConcat`` model.

    The constructor loads the source/target vocabularies and the model
    weights once; ``predict`` then maps one input string to one output
    string.

    Attributes:
        src: Source vocabulary loaded from ``src_vocab_path``.
        tgt: Target vocabulary loaded from ``tgt_vocab_path``.
        model: The ``Seq2SeqConcat`` network, on ``device`` and in eval mode.
    """

    def __init__(self, src_vocab_path, tgt_vocab_path, model_path,
                 max_len=1000, hidden_size=300, n_layers=2):
        """Load vocabularies and weights and prepare the model for inference.

        Args:
            src_vocab_path: Path to the dill-serialised source vocabulary.
            tgt_vocab_path: Path to the dill-serialised target vocabulary.
            model_path: Path to a ``torch.save`` checkpoint containing a
                ``'model_state_dict'`` entry.
            max_len: Maximum sequence length passed to the encoder/decoders.
            hidden_size: Encoder hidden size; the decoders get twice this.
            n_layers: Number of layers for the encoder and both decoders.
        """
        # NOTE(security): dill.load and torch.load unpickle their input and
        # can execute arbitrary code. Only point these paths at trusted files.
        with open(src_vocab_path, 'rb') as f:
            self.src = dill.load(f)
        with open(tgt_vocab_path, 'rb') as f:
            self.tgt = dill.load(f)

        encoder = Encoder(self.src, max_len, hidden_size, n_layers)
        decoder = Decoder(self.tgt, max_len, hidden_size * 2, n_layers)
        reverse_decoder = Decoder(self.tgt, max_len, hidden_size * 2, n_layers,
                                  reverse=True)
        self.model = Seq2SeqConcat(encoder, decoder, reverse_decoder,
                                   self.src.stoi[PAD])

        # Remap stored tensors onto the device we actually run on. This
        # covers the CPU-only case and also checkpoints saved on a CUDA
        # device index that does not exist on this machine.
        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

    def predict(self, text):
        """Run the model on one string and return the decoded prediction.

        Args:
            text: Raw input string; it is normalised with ``preprocess``
                before being fed to the model.

        Returns:
            The predicted characters joined into a string, with the first
            and last predicted symbols dropped (presumably the start/end
            markers -- confirm against the target vocabulary). ``None`` if
            the iterator yields no batch.
        """
        test = Seq2SeqDataset.from_list([self.preprocess(text)])
        test.src_field.vocab = self.src
        test_iterator = BucketIterator(dataset=test, batch_size=1, train=False,
                                       sort=False, sort_within_batch=False,
                                       shuffle=False, device=device)
        with torch.no_grad():
            # The dataset holds exactly one example, so the first batch is
            # the whole answer.
            for batch in test_iterator:
                _, _, output = self.model(batch, has_targets=False,
                                          mask_softmax=1.0, teacher_forcing=1.0)
                # Most likely target symbol at every position.
                _, predicted_indices = torch.max(output, dim=-1)
                predicted_seq = [self.tgt.itos[c]
                                 for c in predicted_indices.squeeze(0).tolist()]
                return ''.join(predicted_seq[1:-1])
        return None

    def preprocess(self, line):
        """Normalise ``line`` to the character set the model is fed.

        Strips surrounding whitespace, lower-cases, transliterates with
        ``unidecode``, then replaces every character in ``PUNCT`` with '-'
        and every character not in ``LEGAL`` with '?'.

        Args:
            line: Raw input string.

        Returns:
            The normalised string.
        """
        line = line.strip().lower()
        line = unidecode(line)
        # Two separate passes on purpose: the '-' written by the first pass
        # is itself checked against LEGAL by the second.
        line = ''.join(c if c not in PUNCT else '-' for c in line)
        line = ''.join(c if c in LEGAL else '?' for c in line)
        return line
# Checkpoint artefacts, given as paths relative to the process working directory.
src_vocab_path = "checkpoint/vocab.src"
tgt_vocab_path = "checkpoint/vocab.tgt"
model_path = "checkpoint/aivivn_tone.model.ep25"
# NOTE(review): wlm_path is not used by any code visible in this file.
wlm_path = "lm/corpus-wplm-4g-v2.binary"
# Build the predictor once at import time so every request reuses the loaded model.
predictor = Minh(src_vocab_path, tgt_vocab_path, model_path)
# Run one prediction at start-up and discard the result
# (presumably a warm-up / smoke test -- confirm).
predictor.predict('nghieng nuoc nghieng thanh')
@api.route('/spell/<string:text>')
class Spell(Resource):
    """REST resource that runs the module-level ``predictor`` on a URL path segment."""

    def get(self, text):
        """Handle ``GET /spell/<text>``.

        Args:
            text: The string captured from the URL path.

        Returns:
            A dict with ``'duration'`` (elapsed time of the prediction,
            formatted via ``str`` on a ``timedelta``) and ``'result'``
            (the value returned by ``predictor.predict``).
        """
        start = datetime.now()
        result = predictor.predict(text)
        duration = datetime.now() - start
        dic = {'duration': str(duration), 'result': result}
        # Echo every response to stdout as a lightweight request log.
        print(dic)
        return dic
if __name__ == '__main__':
    # Start Flask's built-in server on port 5000 when run as a script.
    # NOTE(review): debug=True turns on Flask's debug mode; confirm this entry
    # point is never used for a publicly reachable deployment.
    app.run(port=5000, debug=True)