├── .gitignore ├── README.md ├── file.gif ├── main.py └── models ├── UTransformer.py └── common_layer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .data 2 | *.pyc 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Universal-Transformer-Pytorch 2 | Simple and self-contained implementation of the [Universal Transformer](https://arxiv.org/abs/1807.03819) (Dehghani, 2018) in PyTorch. Please open an issue if you find bugs, and send a pull request if you want to contribute. 3 | 4 | ![](file.gif) 5 | GIF taken from: [https://twitter.com/OriolVinyalsML/status/1017523208059260929](https://twitter.com/OriolVinyalsML/status/1017523208059260929) 6 | 7 | ## Universal Transformer 8 | The basic Transformer model has been taken from [https://github.com/kolloldas/torchnlp](https://github.com/kolloldas/torchnlp). So far, the following has been implemented: 9 | 10 | - Universal Transformer encoder and decoder, with position and time embeddings. 11 | - [Adaptive Computation Time](https://arxiv.org/abs/1603.08983) (Graves, 2016) as described in the Universal Transformer paper. 12 | - Universal Transformer for bAbI data. 13 | 14 | ## Dependencies 15 | ``` 16 | python3 17 | pytorch 0.4 18 | torchtext 19 | argparse 20 | ``` 21 | ## How to run 22 | To run the standard Universal Transformer on bAbI, run: 23 | ``` 24 | python main.py --task 1 25 | ``` 26 | To run it with Adaptive Computation Time: 27 | ``` 28 | python main.py --task 1 --act 29 | ``` 30 | 31 | ## Results 32 | bAbI 10k setting, best result over 10 runs per task. 33 | 34 | Tasks 16, 17, 18 and 19 are very hard to converge, even on the training set. 35 | The problem seems to be the learning-rate scheduling. Moreover, in the 1k setting the results 36 | are still quite poor, so some hyper-parameters probably need tuning. 
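For reference, the per-task sweep is already built into `main.py`: its `__main__` loop runs every task `--run_avg` times (default 10) and prints the max/mean/std accuracy per task. A minimal sketch of the commands behind the two model columns, assuming the argparse defaults:
```
python main.py          # Uni-Trs column
python main.py --act    # + ACT column
```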
37 | 38 | |Task | Uni-Trs| + ACT | Original | 39 | | --- |--- |--- |--- | 40 | | 1 | 0.0 | 0.0 | 0.0 | 41 | | 2 | 0.0 | 0.2 | 0.0 | 42 | | 3 | 0.8 | 2.4 | 0.4 | 43 | | 4 | 0.0 | 0.0 | 0.0 | 44 | | 5 | 0.4 | 0.1 | 0.0 | 45 | | 6 | 0.0 | 0.0 | 0.0 | 46 | | 7 | 0.4 | 0.0 | 0.0 | 47 | | 8 | 0.2 | 0.1 | 0.0 | 48 | | 9 | 0.0 | 0.0 | 0.0 | 49 | | 10 | 0.0 | 0.0 | 0.0 | 50 | | 11 | 0.0 | 0.0 | 0.0 | 51 | | 12 | 0.0 | 0.0 | 0.0 | 52 | | 13 | 0.0 | 0.0 | 0.0 | 53 | | 14 | 0.0 | 0.0 | 0.0 | 54 | | 15 | 0.0 | 0.0 | 0.0 | 55 | | 16 | 50.5 | 50.6 | 0.4 | 56 | | 17 | 13.7 | 14.1 | 0.6 | 57 | | 18 | 4 | 6.9 | 0.0 | 58 | | 19 | 79.2 | 65.2 | 2.8 | 59 | | 20 | 0.0 | 0.0 | 0.0 | 60 | |--- | --- | --- | --- | 61 | | avg | 7.46 | 6.98 | 0.21 | 62 | | fail | 3 | 3 | 0 | 63 | 64 | ## TODO 65 | - Visualize ACT on different tasks 66 | 67 | 87 | 107 | -------------------------------------------------------------------------------- /file.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreamad8/Universal-Transformer-Pytorch/e6b06375269e805a23acbb07ef1aa4d6402bce52/file.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from torchtext import datasets 3 | from torchtext.datasets.babi import BABI20Field 4 | from models.UTransformer import BabiUTransformer 5 | from models.common_layer import NoamOpt 6 | import torch.nn as nn 7 | import torch 8 | import numpy as np 9 | from copy import deepcopy 10 | 11 | def parse_config(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--cuda", action="store_true") 14 | parser.add_argument("--save_path", type=str, default="save/") 15 | parser.add_argument("--task", type=int, default=1) 16 | parser.add_argument("--run_avg", type=int, default=10) 17 | parser.add_argument("--heads", type=int, default=2) 18 | parser.add_argument("--depth", type=int, default=128) 19 | parser.add_argument("--filter", type=int, default=128) 20 | parser.add_argument("--max_hops", type=int, default=6) 21 | parser.add_argument("--batch_size", type=int, default=100) 22 | parser.add_argument("--emb", type=int, default=128) 23 | parser.add_argument("--lr", type=float, default=0.001) 24 | parser.add_argument("--act", action="store_true") 25 | parser.add_argument("--act_loss_weight", type=float, default=0.001) 26 | parser.add_argument("--noam", action="store_true") 27 | parser.add_argument("--verbose", action="store_true") 28 | return parser.parse_args() 29 | 30 | 31 | def get_babi_vocab(task): 32 | text = BABI20Field(70) 33 | train, val, test = datasets.BABI20.splits(text, root='.data', task=task, joint=False, 34 | tenK=True, only_supporting=False) 35 | text.build_vocab(train) 36 | vocab_len = len(text.vocab.freqs) 37 | # print("VOCAB LEN:",vocab_len ) 38 | return vocab_len + 1 39 | 40 | def evaluate(model, criterion, loader): 41 | model.eval() 42 | acc = [] 43 | loss = [] 44 | for b in loader: 45 | story, query, answer = b.story,b.query,b.answer.squeeze() 46 | if(config.cuda): story, query, answer = story.cuda(), query.cuda(), answer.cuda() 47 | pred_prob = model(story, query) 48 | loss.append(criterion(pred_prob[0], answer).item()) 49 | pred = pred_prob[1].data.max(1)[1] # max func return (max, argmax) 50 | acc.append( pred.eq(answer.data).cpu().numpy() ) 51 | 52 | acc = np.concatenate(acc) 53 | acc = np.mean(acc) 54 | loss = np.mean(loss) 55 | return acc,loss 56 | 57 | def 
main(config): 58 | vocab_len = get_babi_vocab(config.task) 59 | train_iter, val_iter, test_iter = datasets.BABI20.iters(batch_size=config.batch_size, 60 | root='.data', 61 | memory_size=70, 62 | task=config.task, 63 | joint=False, 64 | tenK=False, 65 | only_supporting=False, 66 | sort=False, 67 | shuffle=True) 68 | model = BabiUTransformer(num_vocab=vocab_len, 69 | embedding_size=config.emb, 70 | hidden_size=config.emb, 71 | num_layers=config.max_hops, 72 | num_heads=config.heads, 73 | total_key_depth=config.depth, 74 | total_value_depth=config.depth, 75 | filter_size=config.filter, 76 | act=config.act) 77 | if(config.verbose): 78 | print(model) 79 | print("ACT",config.act) 80 | if(config.cuda): model.cuda() 81 | 82 | criterion = nn.CrossEntropyLoss() 83 | if(config.noam): 84 | opt = NoamOpt(config.emb, 1, 4000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 85 | else: 86 | opt = torch.optim.Adam(model.parameters(),lr=config.lr) 87 | 88 | if(config.verbose): 89 | acc_val, loss_val = evaluate(model, criterion, val_iter) 90 | print("RAND_VAL ACC:{:.4f}\t RAND_VAL LOSS:{:.4f}".format(acc_val, loss_val)) 91 | correct = [] 92 | loss_nb = [] 93 | cnt_batch = 0 94 | avg_best = 0 95 | cnt = 0 96 | model.train() 97 | for b in train_iter: 98 | story, query, answer = b.story,b.query,b.answer.squeeze() 99 | if(config.cuda): story, query, answer = story.cuda(), query.cuda(), answer.cuda() 100 | if(config.noam): 101 | opt.optimizer.zero_grad() 102 | else: 103 | opt.zero_grad() 104 | pred_prob = model(story, query) 105 | loss = criterion(pred_prob[0], answer) 106 | if(config.act): 107 | R_t = pred_prob[2][0] 108 | N_t = pred_prob[2][1] 109 | p_t = R_t + N_t 110 | avg_p_t = torch.sum(torch.sum(p_t,dim=1)/p_t.size(1))/p_t.size(0) 111 | loss += config.act_loss_weight * avg_p_t.item() 112 | 113 | loss.backward() 114 | opt.step() 115 | 116 | ## LOG 117 | loss_nb.append(loss.item()) 118 | pred = pred_prob[1].data.max(1)[1] # max func return (max, argmax) 119 | correct.append(np.mean(pred.eq(answer.data).cpu().numpy())) 120 | cnt_batch += 1 121 | if(cnt_batch % 10 == 0): 122 | acc = np.mean(correct) 123 | loss_nb = np.mean(loss_nb) 124 | if(config.verbose): 125 | print("TRN ACC:{:.4f}\tTRN LOSS:{:.4f}".format(acc, loss_nb)) 126 | 127 | acc_val, loss_val = evaluate(model, criterion, val_iter) 128 | if(config.verbose): 129 | print("VAL ACC:{:.4f}\tVAL LOSS:{:.4f}".format(acc_val, loss_val)) 130 | 131 | if(acc_val > avg_best): 132 | avg_best = acc_val 133 | weights_best = deepcopy(model.state_dict()) 134 | cnt = 0 135 | else: 136 | cnt += 1 137 | if(cnt == 45): break 138 | if(avg_best == 1.0): break 139 | 140 | correct = [] 141 | loss_nb = [] 142 | cnt_batch = 0 143 | 144 | 145 | model.load_state_dict({ name: weights_best[name] for name in weights_best }) 146 | acc_test, loss_test = evaluate(model, criterion, test_iter) 147 | if(config.verbose): 148 | print("TST ACC:{:.4f}\tTST LOSS:{:.4f}".format(acc_val, loss_val)) 149 | return acc_test 150 | 151 | if __name__ == "__main__": 152 | config = parse_config() 153 | for t in range(1,21): 154 | config.task = t 155 | acc = [] 156 | for i in range(config.run_avg): 157 | acc.append(main(config)) 158 | print("Noam",config.noam,"ACT",config.act,"Task:",config.task,"Max:",max(acc),"Mean:",np.mean(acc),"Std:",np.std(acc)) 159 | 160 | -------------------------------------------------------------------------------- /models/UTransformer.py: -------------------------------------------------------------------------------- 1 | ### TAKEN FROM 
https://github.com/kolloldas/torchnlp 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torch.nn.init as I 8 | import numpy as np 9 | import math 10 | from models.common_layer import EncoderLayer ,DecoderLayer ,MultiHeadAttention ,Conv ,PositionwiseFeedForward ,LayerNorm ,_gen_bias_mask ,_gen_timing_signal 11 | 12 | class BabiUTransformer(nn.Module): 13 | """ 14 | A Transformer Module For BabI data. 15 | Inputs should be in the shape story: [batch_size, memory_size, story_len ] 16 | query: [batch_size, 1, story_len] 17 | Outputs will have the shape [batch_size, ] 18 | """ 19 | def __init__(self, num_vocab, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth, 20 | filter_size, max_length=71, input_dropout=0.0, layer_dropout=0.0, 21 | attention_dropout=0.0, relu_dropout=0.0, use_mask=False, act=False ): 22 | super(BabiUTransformer, self).__init__() 23 | self.embedding_dim = embedding_size 24 | self.emb = nn.Embedding(num_vocab, embedding_size, padding_idx=0) 25 | self.transformer_enc = Encoder(embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth, 26 | filter_size, max_length=71, input_dropout=input_dropout, layer_dropout=layer_dropout, 27 | attention_dropout=attention_dropout, relu_dropout=relu_dropout, use_mask=False, act=act) 28 | 29 | self.W = nn.Linear(self.embedding_dim,num_vocab) 30 | 31 | 32 | # Share the weight matrix between target word embedding & the final logit dense layer 33 | self.W.weight = self.emb.weight 34 | 35 | self.softmax = nn.Softmax(dim=1) 36 | ## POSITIONAL MASK 37 | self.mask = nn.Parameter(I.constant_(torch.empty(11, self.embedding_dim), 1)) 38 | 39 | def forward(self,story, query): 40 | 41 | story_size = story.size() 42 | ## STORY ENCODER + MUlt Mask 43 | embed = self.emb(story.view(story.size(0), -1)) 44 | embed = embed.view(story_size+(embed.size(-1),)) 45 | embed_story = torch.sum(embed*self.mask[:story.size(2),:].unsqueeze(0), 2) 46 | 47 | ## QUERY ENCODER + MUlt Mask 48 | query_embed = self.emb(query) 49 | embed_query = torch.sum(query_embed.unsqueeze(1)*self.mask[:query.size(1),:], 2) 50 | 51 | ## CONCAT STORY AND QUERY 52 | embed = torch.cat([embed_story, embed_query],dim=1) 53 | 54 | ## APPLY TRANSFORMER 55 | logit, act = self.transformer_enc(embed) 56 | 57 | a_hat = self.W(torch.sum(logit,dim=1)/logit.size(1)) ## reduce mean 58 | 59 | return a_hat, self.softmax(a_hat), act 60 | 61 | 62 | 63 | class Encoder(nn.Module): 64 | """ 65 | A Transformer Encoder module. 66 | Inputs should be in the shape [batch_size, length, hidden_size] 67 | Outputs will have the shape [batch_size, length, hidden_size] 68 | Refer Fig.1 in https://arxiv.org/pdf/1706.03762.pdf 69 | """ 70 | def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth, 71 | filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0, 72 | attention_dropout=0.0, relu_dropout=0.0, use_mask=False, act=False): 73 | """ 74 | Parameters: 75 | embedding_size: Size of embeddings 76 | hidden_size: Hidden size 77 | num_layers: Total layers in the Encoder 78 | num_heads: Number of attention heads 79 | total_key_depth: Size of last dimension of keys. Must be divisible by num_head 80 | total_value_depth: Size of last dimension of values. 
Must be divisible by num_head 81 | output_depth: Size last dimension of the final output 82 | filter_size: Hidden size of the middle layer in FFN 83 | max_length: Max sequence length (required for timing signal) 84 | input_dropout: Dropout just after embedding 85 | layer_dropout: Dropout for each layer 86 | attention_dropout: Dropout probability after attention (Should be non-zero only during training) 87 | relu_dropout: Dropout probability after relu in FFN (Should be non-zero only during training) 88 | use_mask: Set to True to turn on future value masking 89 | """ 90 | 91 | super(Encoder, self).__init__() 92 | 93 | self.timing_signal = _gen_timing_signal(max_length, hidden_size) 94 | ## for t 95 | self.position_signal = _gen_timing_signal(num_layers, hidden_size) 96 | 97 | self.num_layers = num_layers 98 | self.act = act 99 | params =(hidden_size, 100 | total_key_depth or hidden_size, 101 | total_value_depth or hidden_size, 102 | filter_size, 103 | num_heads, 104 | _gen_bias_mask(max_length) if use_mask else None, 105 | layer_dropout, 106 | attention_dropout, 107 | relu_dropout) 108 | 109 | self.proj_flag = False 110 | if(embedding_size == hidden_size): 111 | self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False) 112 | self.proj_flag = True 113 | 114 | self.enc = EncoderLayer(*params) 115 | 116 | self.layer_norm = LayerNorm(hidden_size) 117 | self.input_dropout = nn.Dropout(input_dropout) 118 | if(self.act): 119 | self.act_fn = ACT_basic(hidden_size) 120 | 121 | def forward(self, inputs): 122 | 123 | #Add input dropout 124 | x = self.input_dropout(inputs) 125 | 126 | if(self.proj_flag): 127 | # Project to hidden size 128 | x = self.embedding_proj(x) 129 | 130 | if(self.act): 131 | x, (remainders,n_updates) = self.act_fn(x, inputs, self.enc, self.timing_signal, self.position_signal, self.num_layers) 132 | return x, (remainders,n_updates) 133 | else: 134 | for l in range(self.num_layers): 135 | x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data) 136 | x += self.position_signal[:, l, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data) 137 | x = self.enc(x) 138 | return x, None 139 | 140 | def get_attn_key_pad_mask(seq_k, seq_q): 141 | ''' For masking out the padding part of key sequence. ''' 142 | # Expand to fit the shape of key query attention matrix. 143 | len_q = seq_q.size(1) 144 | PAD = 0 145 | padding_mask = seq_k.eq(PAD) 146 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 147 | 148 | return padding_mask 149 | 150 | class Decoder(nn.Module): 151 | """ 152 | A Transformer Decoder module. 153 | Inputs should be in the shape [batch_size, length, hidden_size] 154 | Outputs will have the shape [batch_size, length, hidden_size] 155 | Refer Fig.1 in https://arxiv.org/pdf/1706.03762.pdf 156 | """ 157 | def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth, 158 | filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0, 159 | attention_dropout=0.0, relu_dropout=0.0, act=False): 160 | """ 161 | Parameters: 162 | embedding_size: Size of embeddings 163 | hidden_size: Hidden size 164 | num_layers: Total layers in the Encoder 165 | num_heads: Number of attention heads 166 | total_key_depth: Size of last dimension of keys. Must be divisible by num_head 167 | total_value_depth: Size of last dimension of values. 
Must be divisible by num_head 168 | output_depth: Size last dimension of the final output 169 | filter_size: Hidden size of the middle layer in FFN 170 | max_length: Max sequence length (required for timing signal) 171 | input_dropout: Dropout just after embedding 172 | layer_dropout: Dropout for each layer 173 | attention_dropout: Dropout probability after attention (Should be non-zero only during training) 174 | relu_dropout: Dropout probability after relu in FFN (Should be non-zero only during training) 175 | """ 176 | 177 | super(Decoder, self).__init__() 178 | 179 | self.timing_signal = _gen_timing_signal(max_length, hidden_size) 180 | self.position_signal = _gen_timing_signal(num_layers, hidden_size) 181 | self.num_layers = num_layers 182 | self.act = act 183 | params =(hidden_size, 184 | total_key_depth or hidden_size, 185 | total_value_depth or hidden_size, 186 | filter_size, 187 | num_heads, 188 | _gen_bias_mask(max_length), # mandatory 189 | layer_dropout, 190 | attention_dropout, 191 | relu_dropout) 192 | 193 | self.proj_flag = False 194 | if(embedding_size == hidden_size): 195 | self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False) 196 | self.proj_flag = True 197 | self.dec = DecoderLayer(*params) 198 | 199 | self.layer_norm = LayerNorm(hidden_size) 200 | self.input_dropout = nn.Dropout(input_dropout) 201 | if(self.act): 202 | self.act_fn = ACT_basic(hidden_size) 203 | 204 | def forward(self, inputs, encoder_output): 205 | #Add input dropout 206 | x = self.input_dropout(inputs) 207 | 208 | if(self.proj_flag): 209 | # Project to hidden size 210 | x = self.embedding_proj(x) 211 | 212 | if(self.act): 213 | x, (remainders,n_updates) = self.act_fn(x, inputs, self.dec, self.timing_signal, self.position_signal, self.num_layers, encoder_output) 214 | return x, (remainders,n_updates) 215 | else: 216 | for l in range(self.num_layers): 217 | x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data) 218 | x += self.position_signal[:, l, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data) 219 | x, _ = self.dec((x, encoder_output)) 220 | return x 221 | 222 | 223 | 224 | ### CONVERTED FROM https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/universal_transformer_util.py#L1062 225 | class ACT_basic(nn.Module): 226 | def __init__(self,hidden_size): 227 | super(ACT_basic, self).__init__() 228 | self.sigma = nn.Sigmoid() 229 | self.p = nn.Linear(hidden_size,1) 230 | self.p.bias.data.fill_(1) 231 | self.threshold = 1 - 0.1 232 | 233 | def forward(self, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None): 234 | # init_hdd 235 | ## [B, S] 236 | halting_probability = torch.zeros(inputs.shape[0],inputs.shape[1]).cuda() 237 | ## [B, S] 238 | remainders = torch.zeros(inputs.shape[0],inputs.shape[1]).cuda() 239 | ## [B, S] 240 | n_updates = torch.zeros(inputs.shape[0],inputs.shape[1]).cuda() 241 | ## [B, S, HDD] 242 | previous_state = torch.zeros_like(inputs).cuda() 243 | step = 0 244 | # for l in range(self.num_layers): 245 | while( ((halting_probability < self.threshold) & (n_updates < max_hop)).byte().any()): 246 | # Add timing (position) and step (hop) signal before each application of the shared layer 247 | state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data) 248 | state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data) 249 | 250 | p = self.sigma(self.p(state)).squeeze(-1) 251 | # Mask for inputs which have not halted yet 252 | still_running = (halting_probability < 1.0).float() 253 | 254 | # Mask of inputs which halted at this step 255 | new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running 256 | 257 | # Mask of inputs which haven't halted, and didn't halt this step 258 | still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running 259 | 260 | # Add the halting probability for this step to the halting 261 | # probabilities for those inputs which haven't halted yet 262 | halting_probability = halting_probability + p * still_running 263 | 264 | # Compute remainders for the
inputs which halted at this step 265 | remainders = remainders + new_halted * (1 - halting_probability) 266 | 267 | # Add the remainders to those inputs which halted at this step 268 | halting_probability = halting_probability + new_halted * remainders 269 | 270 | # Increment n_updates for all inputs which are still running 271 | n_updates = n_updates + still_running + new_halted 272 | 273 | # Compute the weight to be applied to the new state and output 274 | # 0 when the input has already halted 275 | # p when the input hasn't halted yet 276 | # the remainders when it halted this step 277 | update_weights = p * still_running + new_halted * remainders 278 | 279 | if(encoder_output): 280 | state, _ = fn((state,encoder_output)) 281 | else: 282 | # apply transformation on the state 283 | state = fn(state) 284 | 285 | # update running part in the weighted state and keep the rest 286 | previous_state = ((state * update_weights.unsqueeze(-1)) + (previous_state * (1 - update_weights.unsqueeze(-1)))) 287 | ## previous_state is actually the new_state at end of hte loop 288 | ## to save a line I assigned to previous_state so in the next 289 | ## iteration is correct. Notice that indeed we return previous_state 290 | step+=1 291 | return previous_state, (remainders,n_updates) 292 | -------------------------------------------------------------------------------- /models/common_layer.py: -------------------------------------------------------------------------------- 1 | ### MOSTO OF IT TAKEN FROM https://github.com/kolloldas/torchnlp 2 | ## MINOR CHANGES 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torch.nn.init as I 8 | import numpy as np 9 | import math 10 | 11 | 12 | class EncoderLayer(nn.Module): 13 | """ 14 | Represents one Encoder layer of the Transformer Encoder 15 | Refer Fig. 1 in https://arxiv.org/pdf/1706.03762.pdf 16 | NOTE: The layer normalization step has been moved to the input as per latest version of T2T 17 | """ 18 | def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, 19 | bias_mask=None, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0): 20 | """ 21 | Parameters: 22 | hidden_size: Hidden size 23 | total_key_depth: Size of last dimension of keys. Must be divisible by num_head 24 | total_value_depth: Size of last dimension of values. 
Must be divisible by num_head 25 | output_depth: Size last dimension of the final output 26 | filter_size: Hidden size of the middle layer in FFN 27 | num_heads: Number of attention heads 28 | bias_mask: Masking tensor to prevent connections to future elements 29 | layer_dropout: Dropout for this layer 30 | attention_dropout: Dropout probability after attention (Should be non-zero only during training) 31 | relu_dropout: Dropout probability after relu in FFN (Should be non-zero only during training) 32 | """ 33 | 34 | super(EncoderLayer, self).__init__() 35 | 36 | self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth, 37 | hidden_size, num_heads, bias_mask, attention_dropout) 38 | 39 | self.positionwise_feed_forward = PositionwiseFeedForward(hidden_size, filter_size, hidden_size, 40 | layer_config='cc', padding = 'both', 41 | dropout=relu_dropout) 42 | self.dropout = nn.Dropout(layer_dropout) 43 | self.layer_norm_mha = LayerNorm(hidden_size) 44 | self.layer_norm_ffn = LayerNorm(hidden_size) 45 | 46 | def forward(self, inputs): 47 | x = inputs 48 | 49 | # Layer Normalization 50 | x_norm = self.layer_norm_mha(x) 51 | 52 | # Multi-head attention 53 | y = self.multi_head_attention(x_norm, x_norm, x_norm) 54 | 55 | # Dropout and residual 56 | x = self.dropout(x + y) 57 | 58 | # Layer Normalization 59 | x_norm = self.layer_norm_ffn(x) 60 | 61 | # Positionwise Feedforward 62 | y = self.positionwise_feed_forward(x_norm) 63 | 64 | # Dropout and residual 65 | y = self.dropout(x + y) 66 | 67 | return y 68 | 69 | class DecoderLayer(nn.Module): 70 | """ 71 | Represents one Decoder layer of the Transformer Decoder 72 | Refer Fig. 1 in https://arxiv.org/pdf/1706.03762.pdf 73 | NOTE: The layer normalization step has been moved to the input as per latest version of T2T 74 | """ 75 | def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, 76 | bias_mask, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0): 77 | """ 78 | Parameters: 79 | hidden_size: Hidden size 80 | total_key_depth: Size of last dimension of keys. Must be divisible by num_head 81 | total_value_depth: Size of last dimension of values. 
Must be divisible by num_head 82 | output_depth: Size last dimension of the final output 83 | filter_size: Hidden size of the middle layer in FFN 84 | num_heads: Number of attention heads 85 | bias_mask: Masking tensor to prevent connections to future elements 86 | layer_dropout: Dropout for this layer 87 | attention_dropout: Dropout probability after attention (Should be non-zero only during training) 88 | relu_dropout: Dropout probability after relu in FFN (Should be non-zero only during training) 89 | """ 90 | 91 | super(DecoderLayer, self).__init__() 92 | 93 | self.multi_head_attention_dec = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth, 94 | hidden_size, num_heads, bias_mask, attention_dropout) 95 | 96 | self.multi_head_attention_enc_dec = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth, 97 | hidden_size, num_heads, None, attention_dropout) 98 | 99 | self.positionwise_feed_forward = PositionwiseFeedForward(hidden_size, filter_size, hidden_size, 100 | layer_config='cc', padding = 'left', 101 | dropout=relu_dropout) 102 | self.dropout = nn.Dropout(layer_dropout) 103 | self.layer_norm_mha_dec = LayerNorm(hidden_size) 104 | self.layer_norm_mha_enc = LayerNorm(hidden_size) 105 | self.layer_norm_ffn = LayerNorm(hidden_size) 106 | 107 | 108 | def forward(self, inputs): 109 | """ 110 | NOTE: Inputs is a tuple consisting of decoder inputs and encoder output 111 | """ 112 | x, encoder_outputs = inputs 113 | 114 | # Layer Normalization before decoder self attention 115 | x_norm = self.layer_norm_mha_dec(x) 116 | 117 | # Masked Multi-head attention 118 | y = self.multi_head_attention_dec(x_norm, x_norm, x_norm) 119 | 120 | # Dropout and residual after self-attention 121 | x = self.dropout(x + y) 122 | 123 | # Layer Normalization before encoder-decoder attention 124 | x_norm = self.layer_norm_mha_enc(x) 125 | 126 | # Multi-head encoder-decoder attention 127 | y = self.multi_head_attention_enc_dec(x_norm, encoder_outputs, encoder_outputs) 128 | 129 | # Dropout and residual after encoder-decoder attention 130 | x = self.dropout(x + y) 131 | 132 | # Layer Normalization 133 | x_norm = self.layer_norm_ffn(x) 134 | 135 | # Positionwise Feedforward 136 | y = self.positionwise_feed_forward(x_norm) 137 | 138 | # Dropout and residual after positionwise feed forward layer 139 | y = self.dropout(x + y) 140 | 141 | # Return encoder outputs as well to work with nn.Sequential 142 | return y, encoder_outputs 143 | 144 | 145 | 146 | class MultiHeadAttention(nn.Module): 147 | """ 148 | Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf 149 | Refer Figure 2 150 | """ 151 | def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth, 152 | num_heads, bias_mask=None, dropout=0.0): 153 | """ 154 | Parameters: 155 | input_depth: Size of last dimension of input 156 | total_key_depth: Size of last dimension of keys. Must be divisible by num_head 157 | total_value_depth: Size of last dimension of values. 
Must be divisible by num_head 158 | output_depth: Size last dimension of the final output 159 | num_heads: Number of attention heads 160 | bias_mask: Masking tensor to prevent connections to future elements 161 | dropout: Dropout probability (Should be non-zero only during training) 162 | """ 163 | super(MultiHeadAttention, self).__init__() 164 | # Checks borrowed from 165 | # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py 166 | if total_key_depth % num_heads != 0: 167 | raise ValueError("Key depth (%d) must be divisible by the number of " 168 | "attention heads (%d)." % (total_key_depth, num_heads)) 169 | if total_value_depth % num_heads != 0: 170 | raise ValueError("Value depth (%d) must be divisible by the number of " 171 | "attention heads (%d)." % (total_value_depth, num_heads)) 172 | 173 | self.num_heads = num_heads 174 | self.query_scale = (total_key_depth//num_heads)**-0.5 175 | self.bias_mask = bias_mask 176 | 177 | # Key and query depth will be same 178 | self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False) 179 | self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False) 180 | self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False) 181 | self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False) 182 | 183 | self.dropout = nn.Dropout(dropout) 184 | 185 | def _split_heads(self, x): 186 | """ 187 | Split x such to add an extra num_heads dimension 188 | Input: 189 | x: a Tensor with shape [batch_size, seq_length, depth] 190 | Returns: 191 | A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads] 192 | """ 193 | if len(x.shape) != 3: 194 | raise ValueError("x must have rank 3") 195 | shape = x.shape 196 | return x.view(shape[0], shape[1], self.num_heads, shape[2]//self.num_heads).permute(0, 2, 1, 3) 197 | 198 | def _merge_heads(self, x): 199 | """ 200 | Merge the extra num_heads into the last dimension 201 | Input: 202 | x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads] 203 | Returns: 204 | A Tensor with shape [batch_size, seq_length, depth] 205 | """ 206 | if len(x.shape) != 4: 207 | raise ValueError("x must have rank 4") 208 | shape = x.shape 209 | return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3]*self.num_heads) 210 | 211 | def forward(self, queries, keys, values, src_mask=None): 212 | 213 | # Do a linear for each component 214 | queries = self.query_linear(queries) 215 | keys = self.key_linear(keys) 216 | values = self.value_linear(values) 217 | 218 | # Split into multiple heads 219 | queries = self._split_heads(queries) 220 | keys = self._split_heads(keys) 221 | values = self._split_heads(values) 222 | 223 | # Scale queries 224 | queries *= self.query_scale 225 | 226 | # Combine queries and keys 227 | logits = torch.matmul(queries, keys.permute(0, 1, 3, 2)) 228 | 229 | 230 | if src_mask is not None: 231 | logits = logits.masked_fill(src_mask, -np.inf) 232 | 233 | # Add bias to mask future values 234 | if self.bias_mask is not None: 235 | logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits.data) 236 | 237 | # Convert to probabilites 238 | weights = nn.functional.softmax(logits, dim=-1) 239 | 240 | # Dropout 241 | weights = self.dropout(weights) 242 | 243 | # Combine with values to get context 244 | contexts = torch.matmul(weights, values) 245 | 246 | # Merge heads 247 | contexts = self._merge_heads(contexts) 248 | #contexts = torch.tanh(contexts) 249 | 250 | # Linear 
to get output 251 | outputs = self.output_linear(contexts) 252 | 253 | return outputs 254 | 255 | class Conv(nn.Module): 256 | """ 257 | Convenience class that does padding and convolution for inputs in the format 258 | [batch_size, sequence length, hidden size] 259 | """ 260 | def __init__(self, input_size, output_size, kernel_size, pad_type): 261 | """ 262 | Parameters: 263 | input_size: Input feature size 264 | output_size: Output feature size 265 | kernel_size: Kernel width 266 | pad_type: left -> pad on the left side (to mask future data), 267 | both -> pad on both sides 268 | """ 269 | super(Conv, self).__init__() 270 | padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size//2, (kernel_size - 1)//2) 271 | self.pad = nn.ConstantPad1d(padding, 0) 272 | self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0) 273 | 274 | def forward(self, inputs): 275 | inputs = self.pad(inputs.permute(0, 2, 1)) 276 | outputs = self.conv(inputs).permute(0, 2, 1) 277 | 278 | return outputs 279 | 280 | 281 | class PositionwiseFeedForward(nn.Module): 282 | """ 283 | Does a Linear + RELU + Linear on each of the timesteps 284 | """ 285 | def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0): 286 | """ 287 | Parameters: 288 | input_depth: Size of last dimension of input 289 | filter_size: Hidden size of the middle layer 290 | output_depth: Size last dimension of the final output 291 | layer_config: ll -> linear + ReLU + linear 292 | cc -> conv + ReLU + conv etc. 293 | padding: left -> pad on the left side (to mask future data), 294 | both -> pad on both sides 295 | dropout: Dropout probability (Should be non-zero only during training) 296 | """ 297 | super(PositionwiseFeedForward, self).__init__() 298 | 299 | layers = [] 300 | sizes = ([(input_depth, filter_size)] + 301 | [(filter_size, filter_size)]*(len(layer_config)-2) + 302 | [(filter_size, output_depth)]) 303 | 304 | for lc, s in zip(list(layer_config), sizes): 305 | if lc == 'l': 306 | layers.append(nn.Linear(*s)) 307 | elif lc == 'c': 308 | layers.append(Conv(*s, kernel_size=3, pad_type=padding)) 309 | else: 310 | raise ValueError("Unknown layer type {}".format(lc)) 311 | 312 | self.layers = nn.ModuleList(layers) 313 | self.relu = nn.ReLU() 314 | self.dropout = nn.Dropout(dropout) 315 | 316 | def forward(self, inputs): 317 | x = inputs 318 | for i, layer in enumerate(self.layers): 319 | x = layer(x) 320 | if i < len(self.layers): 321 | x = self.relu(x) 322 | x = self.dropout(x) 323 | 324 | return x 325 | 326 | 327 | class LayerNorm(nn.Module): 328 | # Borrowed from jekbradbury 329 | # https://github.com/pytorch/pytorch/issues/1959 330 | def __init__(self, features, eps=1e-6): 331 | super(LayerNorm, self).__init__() 332 | self.gamma = nn.Parameter(torch.ones(features)) 333 | self.beta = nn.Parameter(torch.zeros(features)) 334 | self.eps = eps 335 | 336 | def forward(self, x): 337 | mean = x.mean(-1, keepdim=True) 338 | std = x.std(-1, keepdim=True) 339 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 340 | 341 | 342 | def _gen_bias_mask(max_length): 343 | """ 344 | Generates bias values (-Inf) to mask future timesteps during attention 345 | """ 346 | np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1) 347 | torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor) 348 | 349 | return torch_mask.unsqueeze(0).unsqueeze(1) 350 | 351 | def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4): 352 | """ 353 | 
Generates a [1, length, channels] timing signal consisting of sinusoids 354 | Adapted from: 355 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py 356 | """ 357 | position = np.arange(length) 358 | num_timescales = channels // 2 359 | log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / (float(num_timescales) - 1)) 360 | inv_timescales = min_timescale * np.exp(np.arange(num_timescales).astype(np.float) * -log_timescale_increment) 361 | scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0) 362 | 363 | signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) 364 | signal = np.pad(signal, [[0, 0], [0, channels % 2]], 365 | 'constant', constant_values=[0.0, 0.0]) 366 | signal = signal.reshape([1, length, channels]) 367 | 368 | return torch.from_numpy(signal).type(torch.FloatTensor) 369 | 370 | 371 | def position_encoding(sentence_size, embedding_dim): 372 | encoding = np.ones((embedding_dim, sentence_size), dtype=np.float32) 373 | ls = sentence_size + 1 374 | le = embedding_dim + 1 375 | for i in range(1, le): 376 | for j in range(1, ls): 377 | encoding[i-1, j-1] = (i - (embedding_dim+1)/2) * (j - (sentence_size+1)/2) 378 | encoding = 1 + 4 * encoding / embedding_dim / sentence_size 379 | # Make position encoding of time words identity to avoid modifying them 380 | # encoding[:, -1] = 1.0 381 | return np.transpose(encoding) 382 | 383 | 384 | class LabelSmoothing(nn.Module): 385 | "Implement label smoothing." 386 | def __init__(self, size, padding_idx, smoothing=0.0): 387 | super(LabelSmoothing, self).__init__() 388 | self.criterion = nn.KLDivLoss(reduction='sum') 389 | self.padding_idx = padding_idx 390 | self.confidence = 1.0 - smoothing 391 | self.smoothing = smoothing 392 | self.size = size 393 | self.true_dist = None 394 | 395 | def forward(self, x, target): 396 | assert x.size(1) == self.size 397 | true_dist = x.data.clone() 398 | true_dist.fill_(self.smoothing / (self.size - 2)) 399 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 400 | true_dist[:, self.padding_idx] = 0 401 | mask = torch.nonzero(target.data == self.padding_idx) 402 | if mask.dim() > 0: 403 | true_dist.index_fill_(0, mask.squeeze(), 0.0) 404 | self.true_dist = true_dist 405 | return self.criterion(x, true_dist) 406 | 407 | 408 | class NoamOpt: 409 | "Optim wrapper that implements rate." 410 | def __init__(self, model_size, factor, warmup, optimizer): 411 | self.optimizer = optimizer 412 | self._step = 0 413 | self.warmup = warmup 414 | self.factor = factor 415 | self.model_size = model_size 416 | self._rate = 0 417 | 418 | def step(self): 419 | "Update parameters and rate" 420 | self._step += 1 421 | rate = self.rate() 422 | for p in self.optimizer.param_groups: 423 | p['lr'] = rate 424 | self._rate = rate 425 | self.optimizer.step() 426 | 427 | def rate(self, step = None): 428 | "Implement `lrate` above" 429 | if step is None: 430 | step = self._step 431 | return self.factor * \ 432 | (self.model_size ** (-0.5) * 433 | min(step ** (-0.5), step * self.warmup ** (-1.5))) --------------------------------------------------------------------------------
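As a final sanity check on the model stack above, the snippet below is a hypothetical smoke test (not part of the repository): it builds `BabiUTransformer` with the same default hyper-parameters that `main.py` uses and runs a single forward pass on random token ids, shaped the way `main.py` feeds `b.story` and `b.query`. The vocabulary size, batch size, and sequence lengths are arbitrary assumptions, sentence length must stay within the model's 11-position learned mask, and the repository's own dependency versions (PyTorch 0.4-era stack) are assumed.

```
# Hypothetical smoke test (not part of the repo): one CPU forward pass on random data.
import torch
from models.UTransformer import BabiUTransformer

model = BabiUTransformer(num_vocab=50, embedding_size=128, hidden_size=128,
                         num_layers=6, num_heads=2, total_key_depth=128,
                         total_value_depth=128, filter_size=128, act=False)

# story: [batch_size, memory_size, sentence_len], query: [batch_size, sentence_len]
# sentence_len must be <= 11 because of the learned positional mask in BabiUTransformer.
story = torch.randint(1, 50, (4, 20, 8), dtype=torch.long)
query = torch.randint(1, 50, (4, 8), dtype=torch.long)

logits, probs, act_info = model(story, query)  # act_info is None when act=False
print(logits.shape, probs.shape)               # expected: torch.Size([4, 50]) for both

# Note: act=True would additionally require a GPU, since ACT_basic allocates its
# halting tensors with .cuda().
```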