├── README.md
├── dataloader.py
├── embedding.py
├── machine-translation
│   ├── sample.enfr.zip
│   └── sample.enko.zip
├── machine-transliteration
│   └── generate.py
├── parameters.py
├── predict.py
├── prepare.py
├── rnn_decoder.py
├── rnn_encoder.py
├── rnn_encoder_decoder.py
├── search.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# RNN Encoder-Decoder in PyTorch

A minimal PyTorch implementation of the RNN Encoder-Decoder for sequence-to-sequence learning.

Supported features:
- Mini-batch training with CUDA
- Lookup, CNN, RNN and/or self-attentive encoding in the embedding layer
- Input feeding (Luong et al 2015)
- Attention mechanism (Bahdanau et al 2014, Luong et al 2015)
- CopyNet, copying mechanism (Gu et al 2016)
- Beam search decoding
- Attention visualization

## Usage

Training data should be formatted as below:
```
source_sequence \t target_sequence
source_sequence \t target_sequence
...
```

To prepare data:
```
python3 prepare.py training_data
```

To train:
```
python3 train.py model vocab.src.char_to_idx vocab.src.word_to_idx vocab.tgt.word_to_idx training_data.csv num_epoch
```

To predict:
```
python3 predict.py model.epochN vocab.src.char_to_idx vocab.src.word_to_idx vocab.tgt.word_to_idx test_data
```
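
For example, a toy model can be trained end to end on generated transliteration data (a sketch: `test_data` stands for any file of source sequences, one per line, and the `.epoch100` suffix follows the checkpoint naming in `train.py`):
```
python3 machine-transliteration/generate.py 10000 > toy_data
python3 prepare.py toy_data
python3 train.py model toy_data.src.char_to_idx toy_data.src.word_to_idx toy_data.tgt.word_to_idx toy_data.csv 100
python3 predict.py model.epoch100 toy_data.src.char_to_idx toy_data.src.word_to_idx toy_data.tgt.word_to_idx test_data
```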

## References

Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. 2015. [Neural Machine Translation by Jointly Learning to Align and Translate.](https://arxiv.org/abs/1409.0473) arXiv:1409.0473.

Denny Britz, Anna Goldie, Minh-Thang Luong, Quoc Le. 2017. [Massive Exploration of Neural Machine Translation Architectures.](https://arxiv.org/abs/1703.03906) arXiv:1703.03906.

Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio. 2014. [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation.](https://arxiv.org/abs/1406.1078) arXiv:1406.1078.

Jiatao Gu, Zhengdong Lu, Hang Li, Victor O.K. Li. 2016. [Incorporating Copying Mechanism in Sequence-to-Sequence Learning.](https://arxiv.org/abs/1603.06393) arXiv:1603.06393.

Jiwei Li. 2017. [Teaching Machines to Converse.](https://github.com/jiweil/Jiwei-Thesis/blob/master/thesis.pdf) Doctoral dissertation. Stanford University.

Junyang Lin, Xu Sun, Xuancheng Ren, Muyu Li, Qi Su. 2018. [Learning When to Concentrate or Divert Attention: Self-Adaptive Attention Temperature for Neural Machine Translation.](https://arxiv.org/abs/1808.07374) arXiv:1808.07374.

Minh-Thang Luong, Hieu Pham, Christopher D. Manning. 2015. [Effective Approaches to Attention-based Neural Machine Translation.](https://arxiv.org/abs/1508.04025) arXiv:1508.04025.

Chan Young Park, Yulia Tsvetkov. 2019. [Learning to Generate Word- and Phrase-Embeddings for Efficient Phrase-Based Neural Machine Translation.](https://www.aclweb.org/anthology/D19-5626.pdf) In Proceedings of the 3rd Workshop on Neural Generation and Translation.

Sam Wiseman, Alexander M. Rush. 2016. [Sequence-to-Sequence Learning as Beam-Search Optimization.](https://arxiv.org/abs/1606.02960) arXiv:1606.02960.

Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V. Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, Jeff Klingner, Apurva Shah, Melvin Johnson, Xiaobing Liu, Łukasz Kaiser, Stephan Gouws, Yoshikiyo Kato, Taku Kudo, Hideto Kazawa, Keith Stevens, George Kurian, Nishant Patil, Wei Wang, Cliff Young, Jason Smith, Jason Riesa, Alex Rudnick, Oriol Vinyals, Greg Corrado, Macduff Hughes, Jeffrey Dean. 2016. [Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation.](https://arxiv.org/abs/1609.08144) arXiv:1609.08144.
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from utils import *

class dataset():

    _vars = ("x0", "x1", "xc", "xw", "y0")

    def __init__(self):

        self.idx = None # input index
        self.x0 = [] # text input, raw
        self.x1 = [] # text input, tokenized
        self.xc = [] # indexed input, character-level
        self.xw = [] # indexed input, word-level
        self.y0 = [] # actual output
        self.y1 = None # predicted output
        self.lens = None # sequence lengths (for HRE)
        self.prob = None # output probabilities
        self.attn = None # attention weights
        self.copy = None # copy weights

    def sort(self): # HRE = False

        self.idx = list(range(len(self.xw)))
        self.idx.sort(key = lambda x: -len(self.xw[x]))
        xc = [self.xc[i] for i in self.idx]
        xw = [self.xw[i] for i in self.idx]
        y0 = [self.y0[i] for i in self.idx]
        lens = list(map(len, xw))
        return xc, xw, y0, lens

    def unsort(self):

        self.idx = sorted(range(len(self.x0)), key = lambda x: self.idx[x])
        self.y1 = [self.y1[i] for i in self.idx]
        if self.prob:
            self.prob = [self.prob[i] for i in self.idx]
        if self.attn:
            self.attn = [self.attn[i] for i in self.idx]

class dataloader(dataset):

    def __init__(self, batch_first = False, hre = False):

        super().__init__()
        self.batch_first = batch_first
        self.hre = hre # hierarchical recurrent encoding

    def append_row(self):

        for x in self._vars:
            getattr(self, x).append([])

    def append_item(self, **kwargs):

        for k, v in kwargs.items():
            getattr(self, k)[-1].append(v)

    def clone_row(self):

        for x in self._vars:
            getattr(self, x).append(getattr(self, x)[-1])

    def flatten(self, x): # [Ld, Ls, Lw] -> [Ld * Ls, Lw]

        if self.hre:
            return [list(w) for s in x for w in s]
        try:
            return [s if type(s[0]) == str else list(*s) for s in x]
        except:
            return [w for s in x for w in s]

    def batchify(self, batch_size):

        if self.hre:
            self.x0 = [[x] for x in self.x0]
            self.y0 = [[[y[0] if y else None for y in y]] for y in self.y0]

        for i in range(0, len(self.y0), batch_size):
            batch = dataset()
            j = i + min(batch_size, len(self.x0) - i)
            batch.lens = list(map(len, self.xw[i:j]))
            for x in self._vars:
                setattr(batch, x, self.flatten(getattr(self, x)[i:j]))
            yield batch

    def to_tensor(self, bc = None, bw = None, lens = None, sos = False, eos = False):

        p, s, e = [PAD_IDX], [SOS_IDX], [EOS_IDX]

        if self.hre and lens:
            dl = max(lens) # document length (Ld)
            i, _bc, _bw = 0, [], []
            for j in lens:
                if bc:
                    if sos: _bc.append([[]])
                    _bc += bc[i:i + j] + [[[]] for _ in range(dl - j)]
                    if eos: _bc.append([[]])
                if bw:
                    if sos: _bw.append([])
                    _bw += bw[i:i + j] + [[] for _ in range(dl - j)]
                    if eos: _bw.append([])
                i += j
            bc, bw = _bc, _bw # [B * Ld, ...]

        if bw:
            sl = max(map(len, bw)) # sentence length (Ls)
            bw = [s * sos + x + e * eos + p * (sl - len(x)) for x in bw]
            bw = LongTensor(bw) # [B * Ld, Ls]
            if not self.batch_first:
                bw.transpose_(0, 1)

        if bc:
            wl = max(max(map(len, x)) for x in bc) # word length (Lw)
            wp = [p * (wl + 2)]
            bc = [[s + x + e + p * (wl - len(x)) for x in x] for x in bc]
            bc = [wp * sos + x + wp * (sl - len(x) + eos) for x in bc]
            bc = LongTensor(bc) # [B * Ld, Ls, Lw]
            if not self.batch_first:
                bc.transpose_(0, 1)

        return bc, bw
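
# a minimal usage sketch with hand-indexed toy rows (all indices hypothetical),
# mirroring how train.py and predict.py feed the loader
if __name__ == "__main__":
    _d = dataloader(batch_first = True)
    _d.append_row()
    _d.append_item(x0 = "a b", x1 = ["a", "b"], xc = [[4], [5]], xw = [6, 7], y0 = [8, 9])
    _d.append_row()
    _d.append_item(x0 = "c", x1 = ["c"], xc = [[6]], xw = [8], y0 = [9])
    for _b in _d.batchify(2):
        xc, xw, y0, lens = _b.sort()
        xc, xw = _d.to_tensor(xc, xw, lens, eos = True)
        print(xw) # [B, Ls + 1], EOS-terminated and PAD-padded word indices
        print(xc.size()) # [B, Ls + 1, Lw + 2], SOS/EOS-bracketed character indices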
--------------------------------------------------------------------------------
/embedding.py:
--------------------------------------------------------------------------------
from utils import *

class embed(nn.Module):

    def __init__(self, ls, cti, wti, batch_first = False, hre = False):

        super().__init__()
        self.dim = sum(ls.values())
        self.batch_first = batch_first

        # architecture
        self.char_embed = None
        self.word_embed = None
        self.sent_embed = None

        for model, dim in ls.items():
            assert model in ("lookup", "cnn", "rnn", "sae")
            if model in ("cnn", "rnn"):
                self.char_embed = getattr(self, model)(len(cti), dim)
            if model in ("lookup", "sae"):
                self.word_embed = getattr(self, model)(len(wti), dim)

        if hre:
            self.sent_embed = self.rnn(self.dim, self.dim, hre = True)

        self = self.cuda() if CUDA else self

    def forward(self, b, xc, xw):

        hc, hw = None, None

        if self.char_embed:
            hc = self.char_embed(xc) # [Ls, B * Ld, Lw] -> [Ls, B * Ld, Hc]
        if self.word_embed:
            hw = self.word_embed(xw) # [Ls, B * Ld] -> [Ls, B * Ld, Hw]

        h = torch.cat([h for h in [hc, hw] if type(h) == torch.Tensor], 2)

        if self.sent_embed:
            if self.batch_first:
                h.transpose_(0, 1)
            h = self.sent_embed(h) # [Lw, B * Ld, H] -> [1, B * Ld, H]
            h = h.view(b, -1, h.size(2)) # [B, Ld, H]
            if not self.batch_first:
                h.transpose_(0, 1)

        return h

    class lookup(nn.Module):

        def __init__(self, vocab_size, embed_size):

            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_size, padding_idx = PAD_IDX)

        def forward(self, x):

            return self.embed(x) # [Ls, B * Ld, H]

    class cnn(nn.Module):

        def __init__(self, vocab_size, embed_size):

            super().__init__()
            dim = 50
            num_featmaps = 50 # feature maps generated by each kernel
            kernel_sizes = [3]

            # architecture
            self.embed = nn.Embedding(vocab_size, dim, padding_idx = PAD_IDX)
            self.conv = nn.ModuleList([nn.Conv2d(
                in_channels = 1, # Ci
                out_channels = num_featmaps, # Co
                kernel_size = (i, dim) # height, width
            ) for i in kernel_sizes]) # num_kernels (K)
            self.dropout = nn.Dropout(DROPOUT)
            self.fc = nn.Linear(len(kernel_sizes) * num_featmaps, embed_size)

        def forward(self, x):

            b = x.size(1) # B' = B * Ld
            x = x.reshape(-1, x.size(2)) # [B' * Ls, Lw]
            x = self.embed(x).unsqueeze(1) # [B' * Ls, Ci = 1, Lw, dim]
            h = [conv(x) for conv in self.conv] # [B' * Ls, Co, Lw, 1] * K
            h = [F.relu(k).squeeze(3) for k in h] # [B' * Ls, Co, Lw] * K
            h = [F.max_pool1d(k, k.size(2)).squeeze(2) for k in h] # [B' * Ls, Co] * K
            h = torch.cat(h, 1) # [B' * Ls, Co * K]
            h = self.dropout(h)
            h = self.fc(h) # fully connected layer [B' * Ls, H]
            h = h.view(-1, b, h.size(1)) # [Ls, B', H]

            return h

    class rnn(nn.Module):

        def __init__(self, vocab_size, embed_size, hre = False):

            super().__init__()
            self.dim = embed_size
            self.rnn_type = "GRU" # LSTM, GRU
            self.num_dirs = 2 # unidirectional: 1, bidirectional: 2
            self.num_layers = 2
            self.hre = hre

            # architecture
            self.embed = nn.Embedding(vocab_size, embed_size, padding_idx = PAD_IDX)
            self.rnn = getattr(nn, self.rnn_type)(
                input_size = self.dim,
                hidden_size = self.dim // self.num_dirs,
                num_layers = self.num_layers,
                bias = True,
                dropout = DROPOUT,
                bidirectional = (self.num_dirs == 2)
            )

        def init_state(self, b): # initialize RNN states

            n = self.num_layers * self.num_dirs
            h = self.dim // self.num_dirs
            hs = zeros(n, b, h) # hidden state
            if self.rnn_type == "GRU":
                return hs
            cs = zeros(n, b, h) # LSTM cell state
            return (hs, cs)

        def forward(self, x):

            b = x.size(1) # B' = B * Ld
            s = self.init_state(b * (1 if self.hre else x.size(0)))
            if not self.hre: # [Ls, B', Lw] -> [Lw, B' * Ls, H]
                x = x.reshape(-1, x.size(2)).transpose(0, 1)
            x = self.embed(x)

            h, s = self.rnn(x, s)
            h = s if self.rnn_type == "GRU" else s[-1]
            h = torch.cat([x for x in h[-self.num_dirs:]], 1) # final hidden state
            h = h.view(-1, b, h.size(1)) # [Ls, B', H]

            return h

    class sae(nn.Module): # self-attentive encoder

        def __init__(self, vocab_size, embed_size = 512):

            super().__init__()
            dim = embed_size
            num_layers = 1

            # architecture
            self.embed = nn.Embedding(vocab_size, dim, padding_idx = PAD_IDX)
            self.pe = self.pos_encoding(dim)
            self.layers = nn.ModuleList([self.layer(dim) for _ in range(num_layers)])

        def forward(self, x):

            mask = x.eq(PAD_IDX).view(x.size(0), 1, 1, -1)
            x = self.embed(x)
            h = x + self.pe[:x.size(1)]
            for layer in self.layers:
                h = layer(h, mask)

            return h

        def pos_encoding(self, dim, maxlen = 1000): # positional encoding

            pe = Tensor(maxlen, dim)
            pos = torch.arange(0, maxlen, 1.).unsqueeze(1)
            k = torch.exp(-np.log(10000) * torch.arange(0, dim, 2.) / dim)
            pe[:, 0::2] = torch.sin(pos * k)
            pe[:, 1::2] = torch.cos(pos * k)

            return pe

        class layer(nn.Module): # encoder layer

            def __init__(self, dim):

                super().__init__()

                # architecture
                self.attn = embed.sae.mh_attn(dim)
                self.ffn = embed.sae.ffn(dim)

            def forward(self, x, mask):

                z = self.attn(x, x, x, mask)
                z = self.ffn(z)

                return z

        class mh_attn(nn.Module): # multi-head attention

            def __init__(self, dim):

                super().__init__()
                self.D = dim # dimension of model
                self.H = 8 # number of heads
                self.Dk = self.D // self.H # dimension of key
                self.Dv = self.D // self.H # dimension of value

                # architecture
                self.Wq = nn.Linear(self.D, self.H * self.Dk) # query
                self.Wk = nn.Linear(self.D, self.H * self.Dk) # key
                self.Wv = nn.Linear(self.D, self.H * self.Dv) # value
                self.Wo = nn.Linear(self.H * self.Dv, self.D)
                self.dropout = nn.Dropout(DROPOUT)
                self.norm = nn.LayerNorm(self.D)

            def sdp_attn(self, q, k, v, mask): # scaled dot-product attention

                c = np.sqrt(self.Dk)
                a = torch.matmul(q, k.transpose(2, 3)) / c
                a = a.masked_fill(mask, -10000)
                a = F.softmax(a, 2)
                a = torch.matmul(a, v)

                return a # attention-weighted values

            def forward(self, q, k, v, mask):

                b = q.size(0)
                x = q
                q = self.Wq(q).view(b, -1, self.H, self.Dk).transpose(1, 2)
                k = self.Wk(k).view(b, -1, self.H, self.Dk).transpose(1, 2)
                v = self.Wv(v).view(b, -1, self.H, self.Dv).transpose(1, 2)
                z = self.sdp_attn(q, k, v, mask)
                z = z.transpose(1, 2).contiguous().view(b, -1, self.H * self.Dv)
                z = self.Wo(z)
                z = self.norm(x + self.dropout(z)) # residual connection and dropout

                return z

        class ffn(nn.Module): # position-wise feed-forward networks

            def __init__(self, dim):

                super().__init__()
                dim_ffn = 2048

                # architecture
                self.layers = nn.Sequential(
                    nn.Linear(dim, dim_ffn),
                    nn.ReLU(),
                    nn.Dropout(DROPOUT),
                    nn.Linear(dim_ffn, dim)
                )
                self.norm = nn.LayerNorm(dim)

            def forward(self, x):

                z = x + self.layers(x) # residual connection
                z = self.norm(z) # layer normalization

                return z
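
# a minimal smoke test: embed a [B, Ls] batch of toy word indices with a small
# lookup table (the vocabularies and the 8-dimensional size are hypothetical)
if __name__ == "__main__":
    _e = embed({"lookup": 8}, {PAD: PAD_IDX}, {PAD: PAD_IDX, "a": 1, "b": 2}, batch_first = True)
    _xw = LongTensor([[1, 2], [2, PAD_IDX]])
    print(_e(2, None, _xw).size()) # [B, Ls, 8]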
--------------------------------------------------------------------------------
/machine-translation/sample.enfr.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threelittlemonkeys/rnn-encoder-decoder-pytorch/ba2bf25156b4513093aee4f9c3ab90a97fddc489/machine-translation/sample.enfr.zip
--------------------------------------------------------------------------------
/machine-translation/sample.enko.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threelittlemonkeys/rnn-encoder-decoder-pytorch/ba2bf25156b4513093aee4f9c3ab90a97fddc489/machine-translation/sample.enko.zip
--------------------------------------------------------------------------------
/machine-transliteration/generate.py:
--------------------------------------------------------------------------------
import sys
import random

chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

minlen = 5
maxlen = 20
data_size = int(sys.argv[1])

def transliterate(x):

    if "A" <= x <= "Z":
        return chr(ord(x) + 32)

    if "a" <= x <= "z":
        return chr(ord(x) - 32)

    return x

for _ in range(data_size):

    xs = [random.choice(chars) for _ in range(random.randint(minlen, maxlen))]
    ys = [transliterate(x) for x in xs]

    print(" ".join(xs), " ".join(ys), sep = "\t")
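
# example: `python3 generate.py 1` prints one tab-separated pair of space-joined
# sequences with letter cases swapped and digits unchanged, e.g.
# "F k 9 a" -> "f K 9 A"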
--------------------------------------------------------------------------------
/parameters.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

UNIT = "word" # unit of tokenization (char, word, sent)
MIN_LEN = 1 # minimum sequence length for training
MAX_LEN = 50 # maximum sequence length for training and inference
SRC_VOCAB_SIZE = 50000 # source vocabulary size (0: unlimited)
TGT_VOCAB_SIZE = 50000 # target vocabulary size (0: unlimited)

RNN_TYPE = "GRU" # GRU, LSTM
NUM_DIRS = 2 # number of directions (1: unidirectional, 2: bidirectional)
NUM_LAYERS = 2
HRE = (UNIT == "sent") # hierarchical recurrent encoding
ENC_EMBED = {"lookup": 300} # encoder embedding (cnn, rnn, lookup, sae)
DEC_EMBED = {"lookup": 300} # decoder embedding (lookup only)
HIDDEN_SIZE = 1000
DROPOUT = 0.5
LEARNING_RATE = 2e-4

ATTN = True # attention mechanism
COPY = False # copying mechanism

BEAM_SIZE = 1
BATCH_SIZE = 64

VERBOSE = 0 # 0: None, 1: attention heatmap, 2: beam search
EVAL_EVERY = 10
SAVE_EVERY = 10
NUM_DIGITS = 4 # number of decimal places to print

PAD, PAD_IDX = "<PAD>", 0 # padding
SOS, SOS_IDX = "<SOS>", 1 # start of sequence
EOS, EOS_IDX = "<EOS>", 2 # end of sequence
UNK, UNK_IDX = "<UNK>", 3 # unknown token

CUDA = torch.cuda.is_available()
torch.manual_seed(0) # for reproducibility
# torch.cuda.set_device(0)

assert ATTN != COPY # exactly one of the two mechanisms must be enabled
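
# example: a 250-dimensional word lookup combined with a 50-dimensional character
# CNN (the embedding width is the sum of the values, 250 + 50 = 300):
# ENC_EMBED = {"lookup": 250, "cnn": 50}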
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
from utils import *
from dataloader import *
from rnn_encoder_decoder import *
from search import *

def load_model():

    x_cti = load_tkn_to_idx(sys.argv[2])
    x_wti = load_tkn_to_idx(sys.argv[3])
    y_wti = load_tkn_to_idx(sys.argv[4])
    y_itw = load_idx_to_tkn(sys.argv[4])

    model = rnn_encoder_decoder(x_cti, x_wti, y_wti)
    print(model)

    load_checkpoint(sys.argv[1], model)

    return model, x_cti, x_wti, y_itw

def run_model(model, data, y_itw):

    with torch.no_grad():
        model.eval()

        for batch in data.batchify(BATCH_SIZE * BEAM_SIZE):

            xc, xw, _, lens = batch.sort()
            xc, xw = data.to_tensor(xc, xw, lens, eos = True)
            eos = [False for _ in xw] # EOS states
            b, t = len(xw), 0
            mask, lens = maskset(xw)

            batch.y1 = [[] for _ in xw]
            batch.prob = [0 for _ in xw]
            batch.attn = [[["", *batch.x1[i], EOS]] for i in batch.idx]
            batch.copy = [[["", *batch.x1[i]]] for i in batch.idx]

            model.dec.M, model.dec.H = model.enc(xc, xw, lens)
            model.init_state(b)
            yi = LongTensor([[SOS_IDX]] * b)

            while t < MAX_LEN and sum(eos) < len(eos):
                yo = model.dec(xw, yi, mask)
                args = (model.dec, batch, y_itw, eos, lens, yo)
                yi = beam_search(*args, t) if BEAM_SIZE > 1 else greedy_search(*args)
                t += 1

            batch.unsort()

            if VERBOSE:
                for i in range(0, len(batch.y1), BEAM_SIZE):
                    i //= BEAM_SIZE
                    print("attn[%d] =" % i)
                    print(mat2csv(batch.attn[i]), end = "\n\n")
                    if COPY:
                        print("copy[%d] =" % i)
                        print(mat2csv(batch.copy[i][:-1]), end = "\n\n")

            for i, (x0, y0, y1) in enumerate(zip(batch.x0, batch.y0, batch.y1)):
                if not i % BEAM_SIZE: # use the best candidate from each beam
                    y1 = [y_itw[y] for y in y1[:-1]]
                    yield x0, y0, y1

def predict(model, x_cti, x_wti, y_itw, filename):

    data = dataloader(batch_first = True)
    fo = open(filename)

    for line in fo:
        data.append_row()

        x0, y0 = line.strip(), []
        if x0.count("\t") == 1:
            x0, y0 = x0.split("\t")
        x1 = tokenize(x0)

        xc = [[x_cti.get(c, UNK_IDX) for c in w] for w in x1]
        xw = [x_wti.get(w, UNK_IDX) for w in x1]

        data.append_item(x0 = x0, x1 = x1, xc = xc, xw = xw, y0 = y0)

        for _ in range(BEAM_SIZE - 1):
            data.clone_row()

    fo.close()

    return run_model(model, data, y_itw)

if __name__ == "__main__":

    if len(sys.argv) != 6:
        sys.exit("Usage: %s model vocab.src.char_to_idx vocab.src.word_to_idx vocab.tgt.word_to_idx test_data" % sys.argv[0])

    for x, y0, y1 in predict(*load_model(), sys.argv[5]):
        if y0:
            print((x, y0))
        print((x, y1))
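
# example invocation (hypothetical filenames: the vocabularies come from
# prepare.py and the checkpoint from train.py after epoch 100):
# python3 predict.py model.epoch100 data.src.char_to_idx data.src.word_to_idx data.tgt.word_to_idx test_data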
--------------------------------------------------------------------------------
/prepare.py:
--------------------------------------------------------------------------------
from utils import *

def lineiter(fo):

    for line in fo:
        x, y = line.split("\t")
        x = tokenize(x)
        y = tokenize(y)
        if len(x) < MIN_LEN or len(x) > MAX_LEN:
            continue
        if len(y) < MIN_LEN or len(y) > MAX_LEN:
            continue
        yield x, y

def dict_to_tti(tti, vocab_size = 0):

    tokens = [PAD, SOS, EOS, UNK] # predefined tokens
    tti = sorted(tti, key = lambda x: -tti[x]) # sort by frequency in descending order
    if vocab_size:
        tti = tti[:vocab_size]
    return {w: i for i, w in enumerate(tokens + tti)}

def load_data():

    data = []
    x_cti = defaultdict(int)
    x_wti = defaultdict(int)
    y_wti = defaultdict(int)

    fo = open(sys.argv[1])
    for x, y in lineiter(fo):
        for w in x:
            for c in w:
                x_cti[c] += 1
            x_wti[w] += 1
        for w in y:
            y_wti[w] += 1

    x_cti = dict_to_tti(x_cti)
    x_wti = dict_to_tti(x_wti, SRC_VOCAB_SIZE)
    y_wti = dict_to_tti(y_wti, TGT_VOCAB_SIZE)

    fo.seek(0)
    for x, y in lineiter(fo):
        x = ["+".join(str(x_cti[c]) for c in w) + ":%d" % x_wti.get(w, UNK_IDX) for w in x]
        y = [str(y_wti.get(w, UNK_IDX)) for w in y]
        data.append((x, y))

    fo.close()
    data = sorted(data, key = lambda x: -len(x[0])) # sort by source sequence length

    return data, x_cti, x_wti, y_wti

def save_data(filename, data):

    fo = open(filename, "w")
    for seq in data:
        if not seq:
            print(file = fo)
            continue
        print(*seq[0], end = "\t", file = fo)
        print(*seq[1], file = fo)
    fo.close()

def save_tkn_to_idx(filename, tti):

    fo = open(filename, "w")
    for tkn, _ in sorted(tti.items(), key = lambda x: x[1]):
        fo.write("%s\n" % tkn)
    fo.close()

if __name__ == "__main__":

    if len(sys.argv) != 2:
        sys.exit("Usage: %s training_data" % sys.argv[0])

    data, x_cti, x_wti, y_wti = load_data()
    save_data(sys.argv[1] + ".csv", data)
    save_tkn_to_idx(sys.argv[1] + ".src.char_to_idx", x_cti)
    save_tkn_to_idx(sys.argv[1] + ".src.word_to_idx", x_wti)
    save_tkn_to_idx(sys.argv[1] + ".tgt.word_to_idx", y_wti)
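
# each output line pairs the indexed source and target sequences, tab-separated;
# a source token is its character indices joined by "+", a colon and its word
# index, e.g. (hypothetical indices) "4+5+6:7 8+9:10<TAB>11 12"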
--------------------------------------------------------------------------------
/rnn_decoder.py:
--------------------------------------------------------------------------------
from utils import *
from embedding import *

class rnn_decoder(nn.Module):

    def __init__(self, x_wti, y_wti):

        super().__init__()
        self.M = None # encoder hidden states
        self.H = None # decoder hidden states
        self.h = None # decoder output

        # architecture
        self.embed = embed(DEC_EMBED, None, y_wti, batch_first = True)
        self.rnn = getattr(nn, RNN_TYPE)(
            input_size = self.embed.dim + HIDDEN_SIZE * (1 + COPY),
            hidden_size = HIDDEN_SIZE // NUM_DIRS,
            num_layers = NUM_LAYERS,
            bias = True,
            batch_first = True,
            dropout = DROPOUT,
            bidirectional = (NUM_DIRS == 2)
        )
        self.attn = attn()
        if ATTN:
            self.Wc = nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE)
            self.Wo = nn.Linear(HIDDEN_SIZE, len(y_wti))
            self.softmax = nn.LogSoftmax(1)
        if COPY:
            self.Wo = nn.Linear(HIDDEN_SIZE, len(y_wti))
            self.copy = copy(x_wti, y_wti)

    def forward(self, xw, yi, mask):

        h = self.embed(None, None, yi)

        if ATTN:
            h = torch.cat([h, self.h], 2) # input feeding
            h, self.H = self.rnn(h, self.H)
            self.attn(self.M, h, mask)
            self.h = self.Wc(torch.cat([self.attn.V, h], 2)).tanh()
            h = self.Wo(self.h).squeeze(1) # [B, V]
            yo = self.softmax(h)

        if COPY:
            _M = self.M[:, :-1] # remove EOS token [B, L' = L - 1]
            self.attn(self.M, self.h, mask) # attentive read
            self.copy.attn(_M) # selective read
            h = torch.cat([h, self.attn.V, self.copy.R], 2)
            self.h, self.H = self.rnn(h, self.H)
            g = self.Wo(self.h).squeeze(1) # generation scores [B, V]
            c = self.copy.score(_M, self.h, mask) # copy scores [B, L']
            yo = self.copy.mix(xw, g, c) # [B, V']

        return yo

class attn(nn.Module): # attention mechanism (Luong et al 2015)

    def __init__(self):

        super().__init__()

        # attention weights and context vector, set by forward()
        self.W = None # attention weights
        self.V = None # context vector

    def forward(self, hs, ht, mask):

        a = ht.bmm(hs.transpose(1, 2)) # [B, 1, H] @ [B, H, L] = [B, 1, L]
        a = a.masked_fill(mask.unsqueeze(1), -10000)
        self.W = F.softmax(a, 2) # [B, 1, L]
        self.V = self.W.bmm(hs) # [B, 1, L] @ [B, L, H] = [B, 1, H]

class copy(nn.Module): # copying mechanism (Gu et al 2016)

    def __init__(self, x_wti, y_wti):

        super().__init__()
        self.xyi = {i: y_wti[w] for w, i in x_wti.items() if w in y_wti}
        self.yxi = {i: x_wti[w] for w, i in y_wti.items() if w in x_wti}

        # architecture
        self.R = None # selective read
        self.W = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE) # copy weights

    def attn(self, hs): # selective read

        self.R = self.R.unsqueeze(1).bmm(hs) # [B, 1, L'] @ [B, L', H] = [B, 1, H]

    def score(self, hs, ht, mask): # copy scores

        c = self.W(hs).tanh() # [B, L', H]
        c = ht.bmm(c.transpose(1, 2)) # [B, 1, H] @ [B, H, L'] = [B, 1, L']
        c = c.squeeze(1).masked_fill(mask[:, :-1], -10000) # [B, L']

        self.P = None # generation and copy probabilities
        self.R = F.softmax(c, 1) # selective read weights [B, L']

        return c

    def map(self, xw, vocab_size): # source to target index mapping

        idx = []
        oov = {}

        for i in xw.tolist():
            idx.append([])
            for j in i:
                if j in self.xyi:
                    j = self.xyi[j]
                else:
                    if j not in oov:
                        oov[j] = vocab_size + len(oov)
                    j = oov[j]
                idx[-1].append(j)

        idx = LongTensor(idx) # [B, L']
        oov = Tensor(len(xw), len(oov)).fill_(1e-6) # [B, OOV]
        ohv = zeros(*xw.size(), vocab_size + oov.size(1)) # [B, L', V' = V + OOV]
        ohv = ohv.scatter(2, idx.unsqueeze(2), 1) # one hot vector

        return ohv, oov

    def mix(self, xw, g, c):

        z = F.softmax(torch.cat([g, c], 1), 1) # normalization
        self.P = g, c = z.split([g.size(1), c.size(1)], 1)

        ohv, oov = self.map(xw[:, :-1], g.size(1))
        g = torch.cat([g, oov], 1) # [B, V']
        c = c.unsqueeze(1).bmm(ohv) # [B, 1, L'] @ [B, L', V'] = [B, 1, V']
        z = (g + c.squeeze(1)).log() # mixed probabilities [B, V']

        return z
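
# a minimal shape check for the dot-product attention module above, with random
# stand-ins for the encoder states and the current decoder state
if __name__ == "__main__":
    _attn = attn()
    _hs = torch.randn(2, 5, HIDDEN_SIZE) # encoder states [B, L, H]
    _ht = torch.randn(2, 1, HIDDEN_SIZE) # decoder state [B, 1, H]
    _mask = torch.zeros(2, 5, dtype = torch.bool) # no padding positions
    _attn(_hs, _ht, _mask)
    print(_attn.W.size(), _attn.V.size()) # [2, 1, 5], [2, 1, HIDDEN_SIZE]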
--------------------------------------------------------------------------------
/rnn_encoder.py:
--------------------------------------------------------------------------------
from utils import *
from embedding import *

class rnn_encoder(nn.Module):

    def __init__(self, cti, wti):

        super().__init__()

        # architecture
        self.embed = embed(ENC_EMBED, cti, wti, batch_first = True, hre = HRE)
        self.rnn = getattr(nn, RNN_TYPE)(
            input_size = self.embed.dim,
            hidden_size = HIDDEN_SIZE // NUM_DIRS,
            num_layers = NUM_LAYERS,
            bias = True,
            batch_first = True,
            dropout = DROPOUT,
            bidirectional = (NUM_DIRS == 2)
        )

    def init_state(self, b): # initialize states

        n = NUM_LAYERS * NUM_DIRS
        h = HIDDEN_SIZE // NUM_DIRS
        hs = zeros(n, b, h) # hidden state
        if RNN_TYPE == "GRU":
            return hs
        cs = zeros(n, b, h) # LSTM cell state
        return (hs, cs)

    def forward(self, xc, xw, lens):

        b = len(lens)
        s = self.init_state(b)

        h = self.embed(b, xc, xw)
        h = nn.utils.rnn.pack_padded_sequence(h, lens, batch_first = True)
        h, s = self.rnn(h, s)
        h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first = True)

        return h, s
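
# a minimal smoke test with a toy vocabulary; the indices are hypothetical
# stand-ins for the output of prepare.py
if __name__ == "__main__":
    _wti = {PAD: PAD_IDX, SOS: SOS_IDX, EOS: EOS_IDX, UNK: UNK_IDX, "a": 4, "b": 5}
    _enc = rnn_encoder({PAD: PAD_IDX}, _wti)
    if CUDA: _enc = _enc.cuda()
    _xw = LongTensor([[4, 5, 4], [5, 4, PAD_IDX]]) # [B, L], second row padded
    _h, _s = _enc(None, _xw, [3, 2]) # lengths sorted in descending order
    print(_h.size()) # [B, L, HIDDEN_SIZE]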
--------------------------------------------------------------------------------
/rnn_encoder_decoder.py:
--------------------------------------------------------------------------------
from utils import *
from rnn_encoder import *
from rnn_decoder import *

class rnn_encoder_decoder(nn.Module):

    def __init__(self, x_cti, x_wti, y_wti):

        super().__init__()

        # architecture
        self.enc = rnn_encoder(x_cti, x_wti)
        self.dec = rnn_decoder(x_wti, y_wti)
        if CUDA: self = self.cuda()

    def init_state(self, b):

        self.dec.h = zeros(b, 1, HIDDEN_SIZE)
        self.dec.attn.W = zeros(b, 1, self.dec.M.size(1))
        self.dec.attn.V = zeros(b, 1, HIDDEN_SIZE)

        if COPY:
            self.dec.copy.R = zeros(b, self.dec.M.size(1) - 1)

    def forward(self, xc, xw, y0): # for training

        self.zero_grad()
        b = len(xw) # batch size
        mask, lens = maskset(xw)

        self.dec.M, self.dec.H = self.enc(xc, xw, lens)
        self.init_state(b)
        yi = LongTensor([SOS_IDX] * b)
        yo = []

        for t in range(y0.size(1)):
            yo.append(self.dec(xw, yi.unsqueeze(1), mask))
            yi = y0[:, t] # teacher forcing

        yo = torch.stack(yo).transpose(0, 1).flatten(0, 1)
        loss = F.nll_loss(yo, y0.view(-1), ignore_index = PAD_IDX)

        return loss
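
# a minimal end-to-end smoke test: one teacher-forced training step on a toy
# batch (vocabularies and index sequences are hypothetical)
if __name__ == "__main__":
    _base = {PAD: PAD_IDX, SOS: SOS_IDX, EOS: EOS_IDX, UNK: UNK_IDX}
    _model = rnn_encoder_decoder(dict(_base), {**_base, "a": 4}, {**_base, "b": 4})
    _xw = LongTensor([[4, EOS_IDX], [4, EOS_IDX]]) # source batch [B, L]
    _y0 = LongTensor([[4, EOS_IDX], [4, EOS_IDX]]) # target batch [B, T]
    _loss = _model(None, _xw, _y0)
    _loss.backward()
    print(_loss.item()) # average negative log-likelihood per token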
{sys.argv[5]}") 14 | 15 | fo = open(sys.argv[5]) 16 | for line in fo: 17 | x, y = line.strip().split("\t") 18 | x = [x.split(":") for x in x.split(" ")] 19 | y = list(map(int, y.split(" "))) 20 | xc, xw = zip(*[(list(map(int, xc.split("+"))), int(xw)) for xc, xw in x]) 21 | data.append_row() 22 | data.append_item(xc = xc, xw = xw, y0 = y) 23 | fo.close() 24 | 25 | for _batch in data.batchify(BATCH_SIZE): 26 | xc, xw, y0, lens = _batch.sort() 27 | xc, xw = data.to_tensor(xc, xw, lens, eos = True) 28 | _, y0 = data.to_tensor(None, y0, eos = True) 29 | batch.append((xc, xw, y0)) 30 | 31 | print("data size: %d" % len(data.y0)) 32 | print("batch size: %d" % (BATCH_SIZE)) 33 | 34 | return batch, x_cti, x_wti, y_wti 35 | 36 | def train(): 37 | 38 | num_epochs = int(sys.argv[-1]) 39 | batch, x_cti, x_wti, y_wti = load_data() 40 | model = rnn_encoder_decoder(x_cti, x_wti, y_wti) 41 | print(model) 42 | 43 | enc_optim = torch.optim.Adam(model.enc.parameters(), lr = LEARNING_RATE) 44 | dec_optim = torch.optim.Adam(model.dec.parameters(), lr = LEARNING_RATE) 45 | epoch = load_checkpoint(sys.argv[1], model) if isfile(sys.argv[1]) else 0 46 | filename = re.sub("\.epoch[0-9]+$", "", sys.argv[1]) 47 | 48 | print("training model") 49 | 50 | for ei in range(epoch + 1, epoch + num_epochs + 1): 51 | 52 | loss_sum = 0 53 | timer = time() 54 | 55 | for xc, xw, y0 in batch: 56 | 57 | loss = model(xc, xw, y0) # forward pass and compute loss 58 | loss.backward() # compute gradients 59 | enc_optim.step() # update encoder parameters 60 | dec_optim.step() # update decoder parameters 61 | loss_sum += loss.item() 62 | 63 | timer = time() - timer 64 | loss_sum /= len(batch) 65 | 66 | if ei % SAVE_EVERY and ei != epoch + num_epochs: 67 | save_checkpoint("", None, ei, loss_sum, timer) 68 | else: 69 | save_checkpoint(filename, model, ei, loss_sum, timer) 70 | 71 | if __name__ == "__main__": 72 | 73 | if len(sys.argv) != 7: 74 | sys.exit("Usage: %s model vocab.src.char_to_idx vocab.src.word_to_idx vocab.tgt.word_to_idx training_data num_epoch" % sys.argv[0]) 75 | 76 | train() 77 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | from time import time 4 | from os.path import isfile 5 | from collections import defaultdict 6 | from parameters import * 7 | 8 | Tensor = torch.cuda.FloatTensor if CUDA else torch.FloatTensor 9 | LongTensor = torch.cuda.LongTensor if CUDA else torch.LongTensor 10 | triu = lambda *x: torch.triu(*x).cuda() if CUDA else torch.triu 11 | zeros = lambda *x: torch.zeros(*x).cuda() if CUDA else torch.zeros 12 | 13 | def normalize(x): 14 | 15 | if UNIT in ("word", "sent"): 16 | x = re.sub("(?<=[^ ])(?=[,.?!])", " ", x) 17 | x = re.sub("\s+", " ", x) 18 | x = re.sub("^ | $", "", x) 19 | # x = x.lower() 20 | return x 21 | 22 | def tokenize(x, norm = True): 23 | 24 | if norm: 25 | x = normalize(x) 26 | if UNIT == "char": 27 | return list(x) 28 | if UNIT in ("word", "sent"): 29 | return x.split(" ") 30 | 31 | def load_tkn_to_idx(filename): 32 | 33 | print("loading %s" % filename) 34 | tti = {} 35 | fo = open(filename) 36 | for line in fo: 37 | line = line[:-1] 38 | tti[line] = len(tti) 39 | fo.close() 40 | return tti 41 | 42 | def load_idx_to_tkn(filename): 43 | 44 | print("loading %s" % filename) 45 | itt = [] 46 | fo = open(filename) 47 | for line in fo: 48 | line = line[:-1] 49 | itt.append(line) 50 | fo.close() 51 | return itt 52 | 53 | def 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from utils import *
from dataloader import *
from rnn_encoder_decoder import *

def load_data():

    data = dataloader(batch_first = True)
    batch = []
    x_cti = load_tkn_to_idx(sys.argv[2]) # source char_to_idx
    x_wti = load_tkn_to_idx(sys.argv[3]) # source word_to_idx
    y_wti = load_tkn_to_idx(sys.argv[4]) # target word_to_idx

    print(f"loading {sys.argv[5]}")

    fo = open(sys.argv[5])
    for line in fo:
        x, y = line.strip().split("\t")
        x = [x.split(":") for x in x.split(" ")]
        y = list(map(int, y.split(" ")))
        xc, xw = zip(*[(list(map(int, xc.split("+"))), int(xw)) for xc, xw in x])
        data.append_row()
        data.append_item(xc = xc, xw = xw, y0 = y)
    fo.close()

    for _batch in data.batchify(BATCH_SIZE):
        xc, xw, y0, lens = _batch.sort()
        xc, xw = data.to_tensor(xc, xw, lens, eos = True)
        _, y0 = data.to_tensor(None, y0, eos = True)
        batch.append((xc, xw, y0))

    print("data size: %d" % len(data.y0))
    print("batch size: %d" % BATCH_SIZE)

    return batch, x_cti, x_wti, y_wti

def train():

    num_epochs = int(sys.argv[-1])
    batch, x_cti, x_wti, y_wti = load_data()
    model = rnn_encoder_decoder(x_cti, x_wti, y_wti)
    print(model)

    enc_optim = torch.optim.Adam(model.enc.parameters(), lr = LEARNING_RATE)
    dec_optim = torch.optim.Adam(model.dec.parameters(), lr = LEARNING_RATE)
    epoch = load_checkpoint(sys.argv[1], model) if isfile(sys.argv[1]) else 0
    filename = re.sub(r"\.epoch[0-9]+$", "", sys.argv[1])

    print("training model")

    for ei in range(epoch + 1, epoch + num_epochs + 1):

        loss_sum = 0
        timer = time()

        for xc, xw, y0 in batch:

            loss = model(xc, xw, y0) # forward pass and compute loss
            loss.backward() # compute gradients
            enc_optim.step() # update encoder parameters
            dec_optim.step() # update decoder parameters
            loss_sum += loss.item()

        timer = time() - timer
        loss_sum /= len(batch)

        if ei % SAVE_EVERY and ei != epoch + num_epochs:
            save_checkpoint("", None, ei, loss_sum, timer)
        else:
            save_checkpoint(filename, model, ei, loss_sum, timer)

if __name__ == "__main__":

    if len(sys.argv) != 7:
        sys.exit("Usage: %s model vocab.src.char_to_idx vocab.src.word_to_idx vocab.tgt.word_to_idx training_data num_epoch" % sys.argv[0])

    train()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import sys
import re
from time import time
from os.path import isfile
from collections import defaultdict
from parameters import *

Tensor = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if CUDA else torch.LongTensor
triu = lambda *x: torch.triu(*x).cuda() if CUDA else torch.triu(*x)
zeros = lambda *x: torch.zeros(*x).cuda() if CUDA else torch.zeros(*x)

def normalize(x):

    if UNIT in ("word", "sent"):
        x = re.sub("(?<=[^ ])(?=[,.?!])", " ", x)
        x = re.sub(r"\s+", " ", x)
        x = re.sub("^ | $", "", x)
        # x = x.lower()
    return x

def tokenize(x, norm = True):

    if norm:
        x = normalize(x)
    if UNIT == "char":
        return list(x)
    if UNIT in ("word", "sent"):
        return x.split(" ")

def load_tkn_to_idx(filename):

    print("loading %s" % filename)
    tti = {}
    fo = open(filename)
    for line in fo:
        line = line[:-1]
        tti[line] = len(tti)
    fo.close()
    return tti

def load_idx_to_tkn(filename):

    print("loading %s" % filename)
    itt = []
    fo = open(filename)
    for line in fo:
        line = line[:-1]
        itt.append(line)
    fo.close()
    return itt

def load_checkpoint(filename, model = None):

    print("loading %s" % filename)
    checkpoint = torch.load(filename)
    if model:
        model.enc.load_state_dict(checkpoint["enc_state_dict"])
        model.dec.load_state_dict(checkpoint["dec_state_dict"])
    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    print("epoch = %d, loss = %f" % (epoch, loss))
    return epoch

def save_checkpoint(filename, model, epoch, loss, time):

    print("epoch = %d, loss = %f, time = %f" % (epoch, loss, time))
    if filename and model:
        checkpoint = {}
        checkpoint["enc_state_dict"] = model.enc.state_dict()
        checkpoint["dec_state_dict"] = model.dec.state_dict()
        checkpoint["epoch"] = epoch
        checkpoint["loss"] = loss
        torch.save(checkpoint, filename + ".epoch%d" % epoch)
        print("saved %s.epoch%d" % (filename, epoch))

def save_loss(filename, epoch, loss_array):

    fo = open(filename + ".epoch%d.loss" % epoch, "w")
    fo.write("\n".join(map(str, loss_array)) + "\n")
    fo.close()

def maskset(x):

    mask = x.eq(PAD_IDX)
    lens = (x.size(1) - mask.sum(1)).tolist() # equivalent to x.ne(PAD_IDX).sum(1)

    return mask, lens

def mat2csv(m, ch = True, rh = True, delim = "\t"):

    csv = []
    if ch: # column header
        csv.append(m[0]) # source sequence
    for row in m[ch:]:
        csv.append([])
        if rh: # row header
            csv[-1].append(str(row[0])) # target sequence
        csv[-1] += [f"{x:.{NUM_DIGITS}f}" for x in row[rh:]]

    return "\n".join(delim.join(x) for x in csv)

def f1(p, r):

    return 2 * (p * r) / (p + r) if p + r else 0
--------------------------------------------------------------------------------