├── utils ├── __pycache__ │ ├── exp_utils.cpython-37.pyc │ ├── vocabulary.cpython-37.pyc │ ├── data_parallel.cpython-37.pyc │ ├── log_uniform_sampler.cpython-37.pyc │ └── proj_adaptive_softmax.cpython-37.pyc ├── exp_utils.py ├── adaptive_softmax.py ├── data_parallel.py ├── log_uniform_sampler.py ├── proj_adaptive_softmax.py └── vocabulary.py ├── AMInbest.sh ├── RTnbest.sh ├── SWBDnbest.sh ├── README.md ├── run_AMI.sh ├── run_SWBD.sh ├── data_utils.py ├── train_rnn.py ├── rescore.py └── mem_transformer_rnn.py /utils/__pycache__/exp_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BriansIDP/RTLM/HEAD/utils/__pycache__/exp_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/vocabulary.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BriansIDP/RTLM/HEAD/utils/__pycache__/vocabulary.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_parallel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BriansIDP/RTLM/HEAD/utils/__pycache__/data_parallel.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/log_uniform_sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BriansIDP/RTLM/HEAD/utils/__pycache__/log_uniform_sampler.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/proj_adaptive_softmax.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BriansIDP/RTLM/HEAD/utils/__pycache__/proj_adaptive_softmax.cpython-37.pyc -------------------------------------------------------------------------------- /AMInbest.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE} 2 | 3 | exp_no=1 4 | dataset=AMI 5 | headnum=8 6 | layernum=8 7 | layer=l0 8 | tag=_32_rnn_direct_${layer} 9 | 10 | expdir=${dataset}_transformer/AMI_${layernum}_${headnum}${tag} 11 | 12 | model=${expdir}/model.pt 13 | echo ${model} 14 | 15 | python rescore.py \ 16 | --data data/AMI \ 17 | --nbest rescore/time_sorted_eval.100bestlist \ 18 | --model ${model} \ 19 | --lm transformer_rnn_xl_${layer} \ 20 | --lmscale 10 \ 21 | --lookback 32 \ 22 | --subbatchsize 100 \ 23 | --cuda \ 24 | --mem_len 32 \ 25 | --logfile LOGs/nbestlog.txt \ 26 | --ppl \ 27 | -------------------------------------------------------------------------------- /RTnbest.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE} 2 | 3 | exp_no=1 4 | dataset=SWBD 5 | headnum=8 6 | layernum=24 7 | layer=l0 8 | tag=_64_direct_${layer} 9 | 10 | expdir=${dataset}_transformer/${dataset}_${layernum}_${headnum}${tag} 11 | lm=transformer_rnn_xl_${layer} 12 | 13 | model=${expdir}/model.pt 14 | 15 | python forwardSWBD.py \ 16 | --data data/SWBD \ 17 | --nbest rescore_rt/time_sorted_rt03.nbestlist \ 18 | --model ${model} \ 19 | --lm ${lm} \ 20 | --lmscale 10 \ 21 | --lookback 64 \ 22 | --subbatchsize 100 \ 23 | --cuda \ 24 | --mem_len 64 \ 25 | --logfile 
LOGs/nbestlog.txt \ 26 | --ppl \ 27 | -------------------------------------------------------------------------------- /SWBDnbest.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE} 2 | 3 | exp_no=1 4 | dataset=SWBD 5 | headnum=8 6 | layernum=24 7 | layer=l0 8 | tag=_64_direct_${layer} 9 | 10 | expdir=${dataset}_transformer/${dataset}_${layernum}_${headnum}${tag} 11 | lm=transformer_rnn_xl_${layer} 12 | echo ${lm} 13 | export PYTHONPATH="${expdir}/scripts:$PYTHONPATH" 14 | echo $PYTHONPATH 15 | 16 | model=${expdir}/model.pt 17 | echo ${model} 18 | 19 | python rescore.py \ 20 | --data data/SWBD \ 21 | --nbest rescore_swbd/time_sorted_eval2000.nbestlist \ 22 | --model ${model} \ 23 | --lm ${lm} \ 24 | --lmscale 10 \ 25 | --lookback 64 \ 26 | --subbatchsize 100 \ 27 | --cuda \ 28 | --mem_len 64 \ 29 | --logfile LOGs/nbestlog.txt \ 30 | --ppl \ 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RTLM 2 | ## Introduction 3 | 4 | This repository contains the code for the paper "Transformer Language Models with LSTM-based Cross-utterance Information Representation". The code is mainly adapted from the [Transformer XL PyTorch implementation](https://github.com/kimiyoung/transformer-xl.git). Single GPU version is implemented. 5 | 6 | ## Prerequisite 7 | PyTorch 1.0.0 8 | 9 | ## Training 10 | To train on AMI text data
11 | `bash run_AMI.sh train --work_dir PATH_TO_WORK_DIR` 12 | 13 | To train on SWBD text data
14 | `bash run_SWBD.sh train --work_dir PATH_TO_WORK_DIR` 15 | 16 | ## N-best Rescoring 17 | To rescore the AMI nbest list 18 | `bash AMInbest.sh`
19 | To rescore the SWBD (eval2000) nbest list 20 | `bash SWBDnbest.sh`
21 | To rescore the RT03 nbest list 22 | `bash RTnbest.sh`
23 | 24 | Note that the path to the trained LM (--model) and to the nbest list (--nbest) should be modified to your own path. 25 | -------------------------------------------------------------------------------- /run_AMI.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE} 2 | 3 | #!/bin/bash 4 | echo 'Run training...' 5 | python train_rnn.py \ 6 | --cuda \ 7 | --data data/AMI/ \ 8 | --dataset AMI \ 9 | --work_dir AMI_transformer \ 10 | --n_layer 8 \ 11 | --d_model 512 \ 12 | --div_val 1 \ 13 | --n_head 8 \ 14 | --d_head 64 \ 15 | --d_inner 512 \ 16 | --dropout 0.3 \ 17 | --dropatt 0.3 \ 18 | --optim adam \ 19 | --warmup_step 5000 \ 20 | --max_step 15000 \ 21 | --lr 0.0002 \ 22 | --batch_size 24 \ 23 | --tgt_len 32 \ 24 | --mem_len 32 \ 25 | --ext_len 0 \ 26 | --future_len 0 \ 27 | --eval_tgt_len 32 \ 28 | --eval-interval 1600 \ 29 | --attn_type 0 \ 30 | --scheduler cosine \ 31 | --rnnenc \ 32 | --rnndim 512 \ 33 | --layerlist '0' \ 34 | --pen_layerlist '0' \ 35 | --merge_type direct \ 36 | --p_scale 0.0000 \ 37 | ${@:2} 38 | -------------------------------------------------------------------------------- /run_SWBD.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE} 2 | 3 | #!/bin/bash 4 | echo 'Run training...' 5 | python train_rnn.py \ 6 | --cuda \ 7 | --data data/SWBD/ \ 8 | --work_dir SWBD_transformer \ 9 | --dataset AMI \ 10 | --n_layer 24 \ 11 | --d_model 512 \ 12 | --div_val 4 \ 13 | --n_head 8 \ 14 | --d_head 64 \ 15 | --d_inner 1024 \ 16 | --dropout 0.2 \ 17 | --dropatt 0.1 \ 18 | --optim adam \ 19 | --warmup_step 2000 \ 20 | --max_step 200000 \ 21 | --lr 0.00025 \ 22 | --batch_size 32 \ 23 | --tgt_len 64 \ 24 | --mem_len 64 \ 25 | --ext_len 0 \ 26 | --future_len 0 \ 27 | --eval_tgt_len 64 \ 28 | --eval-interval 1600 \ 29 | --attn_type 0 \ 30 | --scheduler cosine \ 31 | --pre_lnorm \ 32 | --rnnenc \ 33 | --rnndim 512 \ 34 | --layerlist '0' \ 35 | --merge_type project \ 36 | --log-interval 200 \ 37 | ${@:2} 38 | -------------------------------------------------------------------------------- /utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os, shutil 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def logging(s, log_path, print_=True, log_=True): 10 | if print_: 11 | print(s) 12 | if log_: 13 | with open(log_path, 'a+') as f_log: 14 | f_log.write(s + '\n') 15 | 16 | def get_logger(log_path, **kwargs): 17 | return functools.partial(logging, log_path=log_path, **kwargs) 18 | 19 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 20 | if debug: 21 | print('Debug Mode : no experiment dir created') 22 | return functools.partial(logging, log_path=None, log_=False) 23 | 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | 27 | print('Experiment dir : {}'.format(dir_path)) 28 | if scripts_to_save is not None: 29 | script_path = os.path.join(dir_path, 'scripts') 30 | if not os.path.exists(script_path): 31 | os.makedirs(script_path) 32 | for script in scripts_to_save: 33 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) 34 | shutil.copyfile(script, dst_file) 35 | 36 | return get_logger(log_path=os.path.join(dir_path, 'log.txt')) 37 | 38 | def save_checkpoint(model, optimizer, path, epoch): 39 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) 40 | 
torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch))) 41 | -------------------------------------------------------------------------------- /utils/adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class AdaptiveLogSoftmax(nn.Module): 10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False): 11 | super(AdaptiveLogSoftmax, self).__init__() 12 | 13 | cutoffs = list(cutoffs) 14 | 15 | if (cutoffs != sorted(cutoffs)) \ 16 | or (min(cutoffs) <= 0) \ 17 | or (max(cutoffs) >= (n_classes - 1)) \ 18 | or (len(set(cutoffs)) != len(cutoffs)) \ 19 | or any([int(c) != c for c in cutoffs]): 20 | 21 | raise ValueError("cutoffs should be a sequence of unique, positive " 22 | "integers sorted in an increasing order, where " 23 | "each value is between 1 and n_classes-1") 24 | 25 | self.in_features = in_features 26 | self.n_classes = n_classes 27 | self.cutoffs = cutoffs + [n_classes] 28 | 29 | self.shortlist_size = self.cutoffs[0] 30 | self.n_clusters = len(self.cutoffs) - 1 31 | self.head_size = self.shortlist_size + self.n_clusters 32 | 33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features)) 34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 35 | 36 | self.keep_order = keep_order 37 | 38 | 39 | def forward(self, hidden, target, weight, bias, keep_order=False): 40 | if hidden.size(0) != target.size(0): 41 | raise RuntimeError('Input and target should have the same size ' 42 | 'in the batch dimension.') 43 | 44 | head_weight = torch.cat( 45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0) 46 | head_bias = torch.cat( 47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0) 48 | 49 | head_logit = F.linear(hidden, head_weight, bias=head_bias) 50 | head_logprob = F.log_softmax(head_logit, dim=1) 51 | 52 | nll = torch.zeros_like(target, 53 | dtype=hidden.dtype, device=hidden.device) 54 | 55 | offset = 0 56 | cutoff_values = [0] + self.cutoffs 57 | for i in range(len(cutoff_values) - 1): 58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1] 59 | 60 | mask_i = (target >= l_idx) & (target < h_idx) 61 | indices_i = mask_i.nonzero().squeeze() 62 | 63 | if indices_i.numel() == 0: 64 | continue 65 | 66 | target_i = target.index_select(0, indices_i) - l_idx 67 | head_logprob_i = head_logprob.index_select(0, indices_i) 68 | 69 | if i == 0: 70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 71 | else: 72 | weight_i = weight[l_idx:h_idx] 73 | bias_i = bias[l_idx:h_idx] 74 | 75 | hidden_i = hidden.index_select(0, indices_i) 76 | 77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i) 78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 79 | 80 | logprob_i = head_logprob_i[:, -i] \ 81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 82 | 83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 84 | nll.index_copy_(0, indices_i, -logprob_i) 85 | else: 86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 87 | 88 | offset += logprob_i.size(0) 89 | 90 | return nll 91 | -------------------------------------------------------------------------------- /utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from 
torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | if len(self.device_ids) == 1: 66 | return self.module(*inputs[0], **kwargs[0]) 67 | replicas = self.replicate(self.module, self.device_ids) 68 | if self.gpu0_bsz == 0: 69 | replicas = replicas[1:] 70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 74 | return parallel_apply(replicas, inputs, kwargs, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids): 77 | bsz = inputs[0].size(self.dim) 78 | num_dev = len(self.device_ids) 79 | gpu0_bsz = self.gpu0_bsz 80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 81 | if gpu0_bsz < bsz_unit: 82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 83 | delta = bsz - sum(chunk_sizes) 84 | for i in range(delta): 85 | chunk_sizes[i + 1] += 1 86 | if gpu0_bsz == 0: 87 | chunk_sizes = chunk_sizes[1:] 88 | else: 89 | return super().scatter(inputs, kwargs, device_ids) 90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 91 | 92 | 
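The BalancedDataParallel class above extends torch.nn.DataParallel so that GPU 0, where the master copy of the model lives and where outputs are gathered, can take a smaller share of each batch than the remaining GPUs. Below is a minimal usage sketch, assuming at least two visible GPUs; the toy nn.Linear module and the tensor shapes are placeholders for illustration only (train_rnn.py wraps MemTransformerLM the same way, with dim=1).

import torch
import torch.nn as nn

from utils.data_parallel import BalancedDataParallel

# Stand-in module; train_rnn.py wraps MemTransformerLM instead.
model = nn.Linear(512, 512).to('cuda')

# GPU 0 receives 4 of the 20 batch elements; the remaining 16 are split
# evenly over the other devices. dim=1 matches the [seq_len x bsz x d_model]
# batch layout used throughout this codebase.
para_model = BalancedDataParallel(4, model, dim=1).to('cuda')

x = torch.randn(32, 20, 512, device='cuda')  # [tgt_len, bsz, d_model]
out = para_model(x)                          # chunks gathered back on GPU 0
print(out.shape)                             # torch.Size([32, 20, 512])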
-------------------------------------------------------------------------------- /utils/log_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | class LogUniformSampler(object): 6 | def __init__(self, range_max, n_sample): 7 | """ 8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 10 | 11 | expected count can be approximated by 1 - (1 - p)^n 12 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 13 | 14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 15 | """ 16 | with torch.no_grad(): 17 | self.range_max = range_max 18 | log_indices = torch.arange(1., range_max+2., 1.).log_() 19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 20 | # print('P', self.dist.numpy().tolist()[-30:]) 21 | 22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 23 | 24 | self.n_sample = n_sample 25 | 26 | def sample(self, labels): 27 | """ 28 | labels: [b1, b2] 29 | Return 30 | true_log_probs: [b1, b2] 31 | samp_log_probs: [n_sample] 32 | neg_samples: [n_sample] 33 | """ 34 | 35 | # neg_samples = torch.empty(0).long() 36 | n_sample = self.n_sample 37 | n_tries = 2 * n_sample 38 | 39 | with torch.no_grad(): 40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 41 | device = labels.device 42 | neg_samples = neg_samples.to(device) 43 | true_log_probs = self.log_q[labels].to(device) 44 | samp_log_probs = self.log_q[neg_samples].to(device) 45 | return true_log_probs, samp_log_probs, neg_samples 46 | 47 | def sample_logits(embedding, bias, labels, inputs, sampler): 48 | """ 49 | embedding: an nn.Embedding layer 50 | bias: [n_vocab] 51 | labels: [b1, b2] 52 | inputs: [b1, b2, n_emb] 53 | sampler: you may use a LogUniformSampler 54 | Return 55 | logits: [b1, b2, 1 + n_sample] 56 | """ 57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 58 | n_sample = neg_samples.size(0) 59 | b1, b2 = labels.size(0), labels.size(1) 60 | all_ids = torch.cat([labels.view(-1), neg_samples]) 61 | all_w = embedding(all_ids) 62 | true_w = all_w[: -n_sample].view(b1, b2, -1) 63 | sample_w = all_w[- n_sample:].view(n_sample, -1) 64 | 65 | all_b = bias[all_ids] 66 | true_b = all_b[: -n_sample].view(b1, b2) 67 | sample_b = all_b[- n_sample:] 68 | 69 | hit = (labels[:, :, None] == neg_samples).detach() 70 | 71 | true_logits = torch.einsum('ijk,ijk->ij', 72 | [true_w, inputs]) + true_b - true_log_probs 73 | sample_logits = torch.einsum('lk,ijk->ijl', 74 | [sample_w, inputs]) + sample_b - samp_log_probs 75 | sample_logits.masked_fill_(hit, -1e30) 76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 77 | 78 | return logits 79 | 80 | 81 | # class LogUniformSampler(object): 82 | # def __init__(self, range_max, unique=False): 83 | # """ 84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 85 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 86 | # """ 87 | # self.range_max = range_max 88 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 90 | 91 | # self.unique = unique 92 | 93 | # if self.unique: 94 | # self.exclude_mask = 
torch.ByteTensor(range_max).fill_(0) 95 | 96 | # def sample(self, n_sample, labels): 97 | # pos_sample, new_labels = labels.unique(return_inverse=True) 98 | # n_pos_sample = pos_sample.size(0) 99 | # n_neg_sample = n_sample - n_pos_sample 100 | 101 | # if self.unique: 102 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 104 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 105 | # else: 106 | # sample_dist = self.dist 107 | 108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 109 | 110 | # sample = torch.cat([pos_sample, neg_sample]) 111 | # sample_prob = self.dist[sample] 112 | 113 | # return new_labels, sample, sample_prob 114 | 115 | 116 | if __name__ == '__main__': 117 | S, B = 3, 4 118 | n_vocab = 10000 119 | n_sample = 5 120 | H = 32 121 | 122 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 123 | 124 | # sampler = LogUniformSampler(n_vocab, unique=False) 125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 126 | 127 | sampler = LogUniformSampler(n_vocab, unique=True) 128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 129 | 130 | # print('true_probs', true_probs.numpy().tolist()) 131 | # print('samp_probs', samp_probs.numpy().tolist()) 132 | # print('neg_samples', neg_samples.numpy().tolist()) 133 | 134 | # print('sum', torch.sum(sampler.dist).item()) 135 | 136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 137 | 138 | embedding = nn.Embedding(n_vocab, H) 139 | bias = torch.zeros(n_vocab) 140 | inputs = torch.Tensor(S, B, H).normal_() 141 | 142 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 143 | print('logits', logits.detach().numpy().tolist()) 144 | print('logits shape', logits.size()) 145 | print('out_labels', out_labels.detach().numpy().tolist()) 146 | print('out_labels shape', out_labels.size()) 147 | 148 | -------------------------------------------------------------------------------- /utils/proj_adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 11 | 12 | class ProjectedAdaptiveLogSoftmax(nn.Module): 13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 14 | keep_order=False): 15 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 16 | 17 | self.n_token = n_token 18 | self.d_embed = d_embed 19 | self.d_proj = d_proj 20 | 21 | self.cutoffs = cutoffs + [n_token] 22 | self.cutoff_ends = [0] + self.cutoffs 23 | self.div_val = div_val 24 | 25 | self.shortlist_size = self.cutoffs[0] 26 | self.n_clusters = len(self.cutoffs) - 1 27 | self.head_size = self.shortlist_size + self.n_clusters 28 | 29 | if self.n_clusters > 0: 30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 32 | 33 | self.out_layers = nn.ModuleList() 34 | self.out_projs = nn.ParameterList() 35 | 36 | if div_val == 1: 37 | for i in range(len(self.cutoffs)): 38 | if d_proj != d_embed: 39 | self.out_projs.append( 40 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 41 | ) 42 | else: 43 | self.out_projs.append(None) 44 | 45 | self.out_layers.append(nn.Linear(d_embed, n_token)) 
46 | else: 47 | for i in range(len(self.cutoffs)): 48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 49 | d_emb_i = d_embed // (div_val ** i) 50 | 51 | self.out_projs.append( 52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 53 | ) 54 | 55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 56 | 57 | self.keep_order = keep_order 58 | 59 | def _compute_logit(self, hidden, weight, bias, proj): 60 | if proj is None: 61 | logit = F.linear(hidden, weight, bias=bias) 62 | else: 63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 64 | proj_hid = F.linear(hidden, proj.t().contiguous()) 65 | logit = F.linear(proj_hid, weight, bias=bias) 66 | # else: 67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 68 | # if bias is not None: 69 | # logit = logit + bias 70 | 71 | return logit 72 | 73 | def forward(self, hidden, target, keep_order=False): 74 | ''' 75 | hidden :: [len*bsz x d_proj] 76 | target :: [len*bsz] 77 | ''' 78 | 79 | if hidden.size(0) != target.size(0): 80 | raise RuntimeError('Input and target should have the same size ' 81 | 'in the batch dimension.') 82 | 83 | if self.n_clusters == 0: 84 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 85 | self.out_layers[0].bias, self.out_projs[0]) 86 | nll = -F.log_softmax(logit, dim=-1) \ 87 | .gather(1, target.unsqueeze(1)).squeeze(1) 88 | else: 89 | # construct weights and biases 90 | weights, biases = [], [] 91 | for i in range(len(self.cutoffs)): 92 | if self.div_val == 1: 93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 94 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 95 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 96 | else: 97 | weight_i = self.out_layers[i].weight 98 | bias_i = self.out_layers[i].bias 99 | 100 | if i == 0: 101 | weight_i = torch.cat( 102 | [weight_i, self.cluster_weight], dim=0) 103 | bias_i = torch.cat( 104 | [bias_i, self.cluster_bias], dim=0) 105 | 106 | weights.append(weight_i) 107 | biases.append(bias_i) 108 | 109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 110 | 111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 112 | head_logprob = F.log_softmax(head_logit, dim=1) 113 | 114 | nll = torch.zeros_like(target, 115 | dtype=hidden.dtype, device=hidden.device) 116 | 117 | offset = 0 118 | cutoff_values = [0] + self.cutoffs 119 | for i in range(len(cutoff_values) - 1): 120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 121 | 122 | mask_i = (target >= l_idx) & (target < r_idx) 123 | indices_i = mask_i.nonzero().squeeze() 124 | 125 | if indices_i.numel() == 0: 126 | continue 127 | 128 | target_i = target.index_select(0, indices_i) - l_idx 129 | head_logprob_i = head_logprob.index_select(0, indices_i) 130 | 131 | if i == 0: 132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 133 | else: 134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 135 | 136 | hidden_i = hidden.index_select(0, indices_i) 137 | 138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 140 | 141 | logprob_i = head_logprob_i[:, -i] \ 142 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 143 | 144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 145 | nll.index_copy_(0, indices_i, -logprob_i) 146 | else: 147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 148 | 149 | offset += logprob_i.size(0) 150 | 151 | return nll 152 | 
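ProjectedAdaptiveLogSoftmax above returns a per-token negative log-likelihood for [len*bsz]-flattened hidden states, using a shortlist head plus tail clusters defined by cutoffs. The following is a small self-contained sketch with toy sizes (the vocabulary size, cutoffs, and dimensions here are illustrative, not values used elsewhere in this repository); with div_val=1 and d_proj == d_embed no projection matrices are allocated.

import torch

from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax

# Toy configuration: 10k-word vocabulary, a 2k-word shortlist and two tail
# clusters; d_proj == d_embed, so the out_projs entries stay None.
n_token, d_embed, d_proj = 10000, 256, 256
crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj,
                                   cutoffs=[2000, 6000], div_val=1)

hidden = torch.randn(32 * 4, d_proj)              # [len*bsz x d_proj]
target = torch.randint(0, n_token, (32 * 4,))     # [len*bsz]

# nll[k] is the negative log-likelihood of target token k.  With the default
# keep_order=False the entries are grouped by cluster rather than kept in the
# original token order; pass keep_order=True to preserve the input order.
nll = crit(hidden, target)
print(nll.shape, nll.mean().item())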
-------------------------------------------------------------------------------- /utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, OrderedDict 3 | 4 | import torch 5 | 6 | class Vocab(object): 7 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, 8 | delimiter=None, vocab_file=None): 9 | self.counter = Counter() 10 | self.special = special 11 | self.min_freq = min_freq 12 | self.max_size = max_size 13 | self.lower_case = lower_case 14 | self.delimiter = delimiter 15 | self.vocab_file = vocab_file 16 | 17 | def tokenize(self, line, add_eos=False, add_double_eos=False): 18 | line = line.strip() 19 | # convert to lower case 20 | if self.lower_case: 21 | line = line.lower() 22 | 23 | # empty delimiter '' will evaluate False 24 | if self.delimiter == '': 25 | symbols = line 26 | else: 27 | symbols = line.split(self.delimiter) 28 | 29 | if add_double_eos: # lm1b 30 | return [''] + symbols + [''] 31 | elif add_eos: 32 | return symbols + [''] 33 | else: 34 | return symbols 35 | 36 | def count_file(self, path, verbose=False, add_eos=False): 37 | if verbose: print('counting file {} ...'.format(path)) 38 | assert os.path.exists(path) 39 | 40 | sents = [] 41 | with open(path, 'r', encoding='utf-8') as f: 42 | for idx, line in enumerate(f): 43 | if verbose and idx > 0 and idx % 500000 == 0: 44 | print(' line {}'.format(idx)) 45 | symbols = self.tokenize(line, add_eos=add_eos) 46 | self.counter.update(symbols) 47 | sents.append(symbols) 48 | 49 | return sents 50 | 51 | def count_sents(self, sents, verbose=False): 52 | """ 53 | sents : a list of sentences, each a list of tokenized symbols 54 | """ 55 | if verbose: print('counting {} sents ...'.format(len(sents))) 56 | for idx, symbols in enumerate(sents): 57 | if verbose and idx > 0 and idx % 500000 == 0: 58 | print(' line {}'.format(idx)) 59 | self.counter.update(symbols) 60 | 61 | def _build_from_file(self, vocab_file): 62 | self.idx2sym = [] 63 | self.sym2idx = OrderedDict() 64 | 65 | with open(vocab_file, 'r', encoding='utf-8') as f: 66 | for line in f: 67 | symb = line.strip().split()[1] 68 | self.add_symbol(symb) 69 | self.unk_idx = self.sym2idx[''] 70 | 71 | def build_vocab(self): 72 | if self.vocab_file: 73 | print('building vocab from {}'.format(self.vocab_file)) 74 | self._build_from_file(self.vocab_file) 75 | print('final vocab size {}'.format(len(self))) 76 | else: 77 | print('building vocab with min_freq={}, max_size={}'.format( 78 | self.min_freq, self.max_size)) 79 | self.idx2sym = [] 80 | self.sym2idx = OrderedDict() 81 | 82 | for sym in self.special: 83 | self.add_special(sym) 84 | 85 | for sym, cnt in self.counter.most_common(self.max_size): 86 | if cnt < self.min_freq: break 87 | self.add_symbol(sym) 88 | 89 | print('final vocab size {} from {} unique tokens'.format( 90 | len(self), len(self.counter))) 91 | 92 | def build_distribution(self): 93 | self.freq_map = [0] * len(self.sym2idx) 94 | for word, idx in self.sym2idx.items(): 95 | self.freq_map[idx] += self.counter[word] 96 | self.freq_map = torch.tensor(self.freq_map) 97 | 98 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 99 | add_double_eos=False): 100 | if verbose: print('encoding file {} ...'.format(path)) 101 | assert os.path.exists(path) 102 | encoded = [] 103 | with open(path, 'r', encoding='utf-8') as f: 104 | for idx, line in enumerate(f): 105 | if verbose and idx > 0 and idx % 500000 == 0: 106 | print(' line 
{}'.format(idx)) 107 | symbols = self.tokenize(line, add_eos=add_eos, 108 | add_double_eos=add_double_eos) 109 | encoded.append(self.convert_to_tensor(symbols)) 110 | 111 | if ordered: 112 | encoded = torch.cat(encoded) 113 | 114 | return encoded 115 | 116 | def encode_sents(self, sents, ordered=False, verbose=False): 117 | if verbose: print('encoding {} sents ...'.format(len(sents))) 118 | encoded = [] 119 | for idx, symbols in enumerate(sents): 120 | if verbose and idx > 0 and idx % 500000 == 0: 121 | print(' line {}'.format(idx)) 122 | encoded.append(self.convert_to_tensor(symbols)) 123 | 124 | if ordered: 125 | encoded = torch.cat(encoded) 126 | 127 | return encoded 128 | 129 | def add_special(self, sym): 130 | if sym not in self.sym2idx: 131 | self.idx2sym.append(sym) 132 | self.sym2idx[sym] = len(self.idx2sym) - 1 133 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 134 | 135 | def add_symbol(self, sym): 136 | if sym not in self.sym2idx: 137 | self.idx2sym.append(sym) 138 | self.sym2idx[sym] = len(self.idx2sym) - 1 139 | 140 | def get_sym(self, idx): 141 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 142 | return self.idx2sym[idx] 143 | 144 | def get_idx(self, sym): 145 | if sym in self.sym2idx: 146 | return self.sym2idx[sym] 147 | else: 148 | # print('encounter unk {}'.format(sym)) 149 | assert '' not in sym 150 | assert hasattr(self, 'unk_idx') 151 | return self.sym2idx.get(sym, self.unk_idx) 152 | 153 | def get_symbols(self, indices): 154 | return [self.get_sym(idx) for idx in indices] 155 | 156 | def get_indices(self, symbols): 157 | return [self.get_idx(sym) for sym in symbols] 158 | 159 | def convert_to_tensor(self, symbols): 160 | return torch.LongTensor(self.get_indices(symbols)) 161 | 162 | def convert_to_sent(self, indices, exclude=None): 163 | if exclude is None: 164 | return ' '.join([self.get_sym(idx) for idx in indices]) 165 | else: 166 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 167 | 168 | def __len__(self): 169 | return len(self.idx2sym) 170 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import glob 3 | 4 | from collections import Counter, OrderedDict 5 | import numpy as np 6 | import torch 7 | 8 | from utils.vocabulary import Vocab 9 | 10 | class LMOrderedIterator(object): 11 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, future_len=None): 12 | """ 13 | data -- LongTensor -- the LongTensor is strictly ordered 14 | """ 15 | self.bsz = bsz 16 | self.bptt = bptt 17 | self.ext_len = ext_len if ext_len is not None else 0 18 | self.future_len = future_len if future_len is not None else 0 19 | 20 | self.device = device 21 | 22 | # Work out how cleanly we can divide the dataset into bsz parts. 23 | self.n_step = data.size(0) // bsz 24 | 25 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 26 | data = data.narrow(0, 0, self.n_step * bsz) 27 | 28 | # Evenly divide the data across the bsz batches. 
29 | self.data = data.view(bsz, -1).t().contiguous().to(device) 30 | 31 | # Number of mini-batches 32 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 33 | 34 | def get_batch(self, i, bptt=None): 35 | if bptt is None: bptt = self.bptt 36 | seq_len = min(bptt, self.data.size(0) - 1 - i) 37 | future_seqlen = min(self.future_len, self.data.size(0) - 1 - i - seq_len) 38 | 39 | end_idx = i + seq_len + future_seqlen 40 | beg_idx = max(0, i - self.ext_len) 41 | 42 | data = self.data[beg_idx:end_idx] 43 | target = self.data[i+1:i+1+seq_len] 44 | 45 | return data, target, seq_len, future_seqlen 46 | 47 | def get_fixlen_iter(self, start=0): 48 | for i in range(start, self.data.size(0) - 1, self.bptt): 49 | yield self.get_batch(i) 50 | 51 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 52 | max_len = self.bptt + max_deviation * std 53 | i = start 54 | while True: 55 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 56 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 57 | data, target, seq_len = self.get_batch(i, bptt) 58 | i += seq_len 59 | yield data, target, seq_len 60 | if i >= self.data.size(0) - 2: 61 | break 62 | 63 | def __iter__(self): 64 | return self.get_fixlen_iter() 65 | 66 | 67 | class LMShuffledIterator(object): 68 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): 69 | """ 70 | data -- list[LongTensor] -- there is no order among the LongTensors 71 | """ 72 | self.data = data 73 | 74 | self.bsz = bsz 75 | self.bptt = bptt 76 | self.ext_len = ext_len if ext_len is not None else 0 77 | 78 | self.device = device 79 | self.shuffle = shuffle 80 | 81 | def get_sent_stream(self): 82 | # index iterator 83 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ 84 | else np.array(range(len(self.data))) 85 | 86 | # sentence iterator 87 | for idx in epoch_indices: 88 | yield self.data[idx] 89 | 90 | def stream_iterator(self, sent_stream): 91 | # streams for each data in the batch 92 | streams = [None] * self.bsz 93 | 94 | data = torch.LongTensor(self.bptt, self.bsz) 95 | target = torch.LongTensor(self.bptt, self.bsz) 96 | 97 | n_retain = 0 98 | 99 | while True: 100 | # data : [n_retain+bptt x bsz] 101 | # target : [bptt x bsz] 102 | data[n_retain:].fill_(-1) 103 | target.fill_(-1) 104 | 105 | valid_batch = True 106 | 107 | for i in range(self.bsz): 108 | n_filled = 0 109 | try: 110 | while n_filled < self.bptt: 111 | if streams[i] is None or len(streams[i]) <= 1: 112 | streams[i] = next(sent_stream) 113 | # number of new tokens to fill in 114 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 115 | # first n_retain tokens are retained from last batch 116 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ 117 | streams[i][:n_new] 118 | target[n_filled:n_filled+n_new, i] = \ 119 | streams[i][1:n_new+1] 120 | streams[i] = streams[i][n_new:] 121 | n_filled += n_new 122 | except StopIteration: 123 | valid_batch = False 124 | break 125 | 126 | if not valid_batch: 127 | return 128 | 129 | data = data.to(self.device) 130 | target = target.to(self.device) 131 | 132 | yield data, target, self.bptt 133 | 134 | n_retain = min(data.size(0), self.ext_len) 135 | if n_retain > 0: 136 | data[:n_retain] = data[-n_retain:] 137 | data.resize_(n_retain + self.bptt, data.size(1)) 138 | 139 | def __iter__(self): 140 | # sent_stream is an iterator 141 | sent_stream = self.get_sent_stream() 142 | 143 | for batch in self.stream_iterator(sent_stream): 144 | yield batch 145 | 146 | 147 | 
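# --- Usage sketch (illustration only; toy values, not from this repository) ---
# LMOrderedIterator above yields (data, target, seq_len, future_seqlen) tuples in
# [seq_len x bsz] layout: `data` can additionally contain ext_len past tokens and
# future_len upcoming tokens, while `target` is the one-step-ahead window of
# seq_len tokens used as prediction targets.  For example:
#
#     toy_corpus = torch.arange(1000)  # stand-in ordered LongTensor corpus
#     it = LMOrderedIterator(toy_corpus, bsz=4, bptt=8, ext_len=0, future_len=2)
#     data, target, seq_len, future_seqlen = it.get_batch(0)
#     # data.shape   -> torch.Size([10, 4])  (seq_len + future_seqlen tokens)
#     # target.shape -> torch.Size([8, 4])   (bptt window shifted by one token)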
class LMMultiFileIterator(LMShuffledIterator): 148 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, 149 | shuffle=False): 150 | 151 | self.paths = paths 152 | self.vocab = vocab 153 | 154 | self.bsz = bsz 155 | self.bptt = bptt 156 | self.ext_len = ext_len if ext_len is not None else 0 157 | 158 | self.device = device 159 | self.shuffle = shuffle 160 | 161 | def get_sent_stream(self, path): 162 | sents = self.vocab.encode_file(path, add_double_eos=True) 163 | if self.shuffle: 164 | np.random.shuffle(sents) 165 | sent_stream = iter(sents) 166 | 167 | return sent_stream 168 | 169 | def __iter__(self): 170 | if self.shuffle: 171 | np.random.shuffle(self.paths) 172 | 173 | for path in self.paths: 174 | # sent_stream is an iterator 175 | sent_stream = self.get_sent_stream(path) 176 | for batch in self.stream_iterator(sent_stream): 177 | yield batch 178 | 179 | 180 | class Corpus(object): 181 | def __init__(self, path, dataset, *args, **kwargs): 182 | self.dataset = dataset 183 | self.vocab = Vocab(*args, **kwargs) 184 | 185 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'AMI']: 186 | self.vocab.count_file(os.path.join(path, 'train.txt')) 187 | self.vocab.count_file(os.path.join(path, 'valid.txt')) 188 | self.vocab.count_file(os.path.join(path, 'test.txt')) 189 | elif self.dataset == 'wt103': 190 | self.vocab.count_file(os.path.join(path, 'train.txt')) 191 | elif self.dataset == 'lm1b': 192 | train_path_pattern = os.path.join( 193 | path, '1-billion-word-language-modeling-benchmark-r13output', 194 | 'training-monolingual.tokenized.shuffled', 'news.en-*') 195 | train_paths = glob.glob(train_path_pattern) 196 | # the vocab will load from file when build_vocab() is called 197 | 198 | self.vocab.build_vocab() 199 | # build unigram distribution of the vocab 200 | self.vocab.build_distribution() 201 | 202 | if self.dataset in ['ptb', 'wt2', 'wt103', 'AMI']: 203 | self.train = self.vocab.encode_file( 204 | os.path.join(path, 'train.txt'), ordered=True) 205 | self.valid = self.vocab.encode_file( 206 | os.path.join(path, 'valid.txt'), ordered=True) 207 | self.test = self.vocab.encode_file( 208 | os.path.join(path, 'test.txt'), ordered=True) 209 | elif self.dataset in ['enwik8', 'text8']: 210 | self.train = self.vocab.encode_file( 211 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False) 212 | self.valid = self.vocab.encode_file( 213 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) 214 | self.test = self.vocab.encode_file( 215 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False) 216 | elif self.dataset == 'lm1b': 217 | self.train = train_paths 218 | self.valid = self.vocab.encode_file( 219 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) 220 | self.test = self.vocab.encode_file( 221 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True) 222 | 223 | def get_iterator(self, split, *args, **kwargs): 224 | if split == 'train': 225 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'AMI']: 226 | data_iter = LMOrderedIterator(self.train, *args, **kwargs) 227 | elif self.dataset == 'lm1b': 228 | kwargs['shuffle'] = True 229 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) 230 | elif split in ['valid', 'test']: 231 | data = self.valid if split == 'valid' else self.test 232 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'AMI']: 233 | data_iter = LMOrderedIterator(data, *args, **kwargs) 234 | elif self.dataset == 'lm1b': 235 | data_iter = 
LMShuffledIterator(data, *args, **kwargs) 236 | 237 | return data_iter 238 | 239 | 240 | def get_lm_corpus(datadir, dataset): 241 | fn = os.path.join(datadir, 'cache.pt') 242 | if os.path.exists(fn): 243 | print('Loading cached dataset...') 244 | corpus = torch.load(fn) 245 | else: 246 | print('Producing dataset {}...'.format(dataset)) 247 | kwargs = {} 248 | if dataset in ['wt103', 'wt2', 'AMI']: 249 | kwargs['special'] = [''] 250 | kwargs['lower_case'] = False 251 | kwargs['vocab_file'] = os.path.join(datadir, 'dictionary.txt') 252 | elif dataset == 'ptb': 253 | kwargs['special'] = [''] 254 | kwargs['lower_case'] = True 255 | elif dataset == 'lm1b': 256 | kwargs['special'] = [] 257 | kwargs['lower_case'] = False 258 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') 259 | elif dataset in ['enwik8', 'text8']: 260 | pass 261 | 262 | corpus = Corpus(datadir, dataset, **kwargs) 263 | # torch.save(corpus, fn) 264 | 265 | return corpus 266 | 267 | if __name__ == '__main__': 268 | import argparse 269 | parser = argparse.ArgumentParser(description='unit test') 270 | parser.add_argument('--datadir', type=str, default='../data/text8', 271 | help='location of the data corpus') 272 | parser.add_argument('--dataset', type=str, default='text8', 273 | choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'], 274 | help='dataset name') 275 | args = parser.parse_args() 276 | 277 | corpus = get_lm_corpus(args.datadir, args.dataset) 278 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym))) 279 | -------------------------------------------------------------------------------- /train_rnn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os, sys 6 | import itertools 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | 14 | from data_utils import get_lm_corpus 15 | from mem_transformer_rnn import MemTransformerLM 16 | from utils.exp_utils import create_exp_dir 17 | from utils.data_parallel import BalancedDataParallel 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 20 | parser.add_argument('--data', type=str, default='../data/wikitext-103', 21 | help='location of the data corpus') 22 | parser.add_argument('--dataset', type=str, default='wt103', 23 | choices=['wt103', 'lm1b', 'enwik8', 'text8', 'AMI'], 24 | help='dataset name') 25 | parser.add_argument('--n_layer', type=int, default=12, 26 | help='number of total layers') 27 | parser.add_argument('--n_head', type=int, default=10, 28 | help='number of heads') 29 | parser.add_argument('--d_head', type=int, default=50, 30 | help='head dimension') 31 | parser.add_argument('--d_embed', type=int, default=-1, 32 | help='embedding dimension') 33 | parser.add_argument('--d_model', type=int, default=500, 34 | help='model dimension') 35 | parser.add_argument('--d_inner', type=int, default=1000, 36 | help='inner dimension in FF') 37 | parser.add_argument('--dropout', type=float, default=0.0, 38 | help='global dropout rate') 39 | parser.add_argument('--dropatt', type=float, default=0.0, 40 | help='attention probability dropout rate') 41 | parser.add_argument('--init', default='normal', type=str, 42 | help='parameter initializer to use.') 43 | parser.add_argument('--emb_init', default='normal', type=str, 44 | help='parameter initializer to use.') 45 | parser.add_argument('--init_range', type=float, default=0.1, 46 | 
help='parameters initialized by U(-init_range, init_range)') 47 | parser.add_argument('--emb_init_range', type=float, default=0.01, 48 | help='parameters initialized by U(-init_range, init_range)') 49 | parser.add_argument('--init_std', type=float, default=0.02, 50 | help='parameters initialized by N(0, init_std)') 51 | parser.add_argument('--proj_init_std', type=float, default=0.01, 52 | help='parameters initialized by N(0, init_std)') 53 | parser.add_argument('--optim', default='adam', type=str, 54 | choices=['adam', 'sgd', 'adagrad'], 55 | help='optimizer to use.') 56 | parser.add_argument('--lr', type=float, default=0.00025, 57 | help='initial learning rate (0.00025|5 for adam|sgd)') 58 | parser.add_argument('--mom', type=float, default=0.0, 59 | help='momentum for sgd') 60 | parser.add_argument('--scheduler', default='cosine', type=str, 61 | choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'], 62 | help='lr scheduler to use.') 63 | parser.add_argument('--warmup_step', type=int, default=0, 64 | help='upper epoch limit') 65 | parser.add_argument('--decay_rate', type=float, default=0.5, 66 | help='decay factor when ReduceLROnPlateau is used') 67 | parser.add_argument('--lr_min', type=float, default=0.0, 68 | help='minimum learning rate during annealing') 69 | parser.add_argument('--clip', type=float, default=0.25, 70 | help='gradient clipping') 71 | parser.add_argument('--clip_nonemb', action='store_true', 72 | help='only clip the gradient of non-embedding params') 73 | parser.add_argument('--max_step', type=int, default=100000, 74 | help='upper epoch limit') 75 | parser.add_argument('--batch_size', type=int, default=60, 76 | help='batch size') 77 | parser.add_argument('--batch_chunk', type=int, default=1, 78 | help='split batch into chunks to save memory') 79 | parser.add_argument('--tgt_len', type=int, default=70, 80 | help='number of tokens to predict') 81 | parser.add_argument('--eval_tgt_len', type=int, default=50, 82 | help='number of tokens to predict for evaluation') 83 | parser.add_argument('--ext_len', type=int, default=0, 84 | help='length of the extended context') 85 | parser.add_argument('--future_len', type=int, default=0, 86 | help='length of the future context') 87 | parser.add_argument('--mem_len', type=int, default=0, 88 | help='length of the retained previous heads') 89 | parser.add_argument('--not_tied', action='store_true', 90 | help='do not tie the word embedding and softmax weights') 91 | parser.add_argument('--seed', type=int, default=1111, 92 | help='random seed') 93 | parser.add_argument('--cuda', action='store_true', 94 | help='use CUDA') 95 | parser.add_argument('--adaptive', action='store_true', 96 | help='use adaptive softmax') 97 | parser.add_argument('--div_val', type=int, default=1, 98 | help='divident value for adapative input and softmax') 99 | parser.add_argument('--pre_lnorm', action='store_true', 100 | help='apply LayerNorm to the input instead of the output') 101 | parser.add_argument('--varlen', action='store_true', 102 | help='use variable length') 103 | parser.add_argument('--multi_gpu', action='store_true', 104 | help='use multiple GPU') 105 | parser.add_argument('--log-interval', type=int, default=200, 106 | help='report interval') 107 | parser.add_argument('--eval-interval', type=int, default=4000, 108 | help='evaluation interval') 109 | parser.add_argument('--work_dir', default='LM-TFM', type=str, 110 | help='experiment directory.') 111 | parser.add_argument('--restart', action='store_true', 112 | help='restart training from the saved 
checkpoint') 113 | parser.add_argument('--restart_dir', type=str, default='', 114 | help='restart dir') 115 | parser.add_argument('--debug', action='store_true', 116 | help='run in debug mode (do not create exp dir)') 117 | parser.add_argument('--same_length', action='store_true', 118 | help='use the same attn length for all tokens') 119 | parser.add_argument('--attn_type', type=int, default=0, 120 | help='attention type. 0 for ours, 1 for Shaw et al,' 121 | '2 for Vaswani et al, 3 for Al Rfou et al.') 122 | parser.add_argument('--clamp_len', type=int, default=-1, 123 | help='use the same pos embeddings after clamp_len') 124 | parser.add_argument('--eta_min', type=float, default=0.0, 125 | help='min learning rate for cosine scheduler') 126 | parser.add_argument('--gpu0_bsz', type=int, default=-1, 127 | help='batch size on gpu 0') 128 | parser.add_argument('--max_eval_steps', type=int, default=-1, 129 | help='max eval steps') 130 | parser.add_argument('--sample_softmax', type=int, default=-1, 131 | help='number of samples in sampled softmax') 132 | parser.add_argument('--patience', type=int, default=0, 133 | help='patience') 134 | parser.add_argument('--finetune_v2', action='store_true', 135 | help='finetune v2') 136 | parser.add_argument('--finetune_v3', action='store_true', 137 | help='finetune v3') 138 | parser.add_argument('--fp16', action='store_true', 139 | help='Run in pseudo-fp16 mode (fp16 storage fp32 math).') 140 | parser.add_argument('--static-loss-scale', type=float, default=1, 141 | help='Static loss scale, positive power of 2 values can ' 142 | 'improve fp16 convergence.') 143 | parser.add_argument('--dynamic-loss-scale', action='store_true', 144 | help='Use dynamic loss scaling. If supplied, this argument' 145 | ' supersedes --static-loss-scale.') 146 | parser.add_argument('--rnnenc', action='store_true', 147 | help='use rnn encoder') 148 | parser.add_argument('--rnnmodel', type=str, default="", 149 | help='load pretrained rnn model') 150 | parser.add_argument('--rnndim', type=int, default=500, 151 | help='dimension of rnn hidden state') 152 | parser.add_argument('--layerlist', type=str, default='0', 153 | help='layers to insert rnn') 154 | parser.add_argument('--evalmode', action='store_true', 155 | help='Do evaluate forward only') 156 | parser.add_argument('--pen_layerlist', type=str, default='0', 157 | help='layers to apply attention penalty') 158 | parser.add_argument('--p_scale', type=float, default=0.0, 159 | help='penalty scaling factor') 160 | parser.add_argument('--merge_type', type=str, default="direct", 161 | help='Type of merging rnn hidden states') 162 | 163 | args = parser.parse_args() 164 | args.tied = not args.not_tied 165 | 166 | if args.d_embed < 0: 167 | args.d_embed = args.d_model 168 | 169 | assert args.ext_len >= 0, 'extended context length must be non-negative' 170 | assert args.batch_size % args.batch_chunk == 0 171 | 172 | if not args.evalmode: 173 | args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S')) 174 | logging = create_exp_dir(args.work_dir, 175 | scripts_to_save=['train_rnn.py', 'mem_transformer_rnn.py'], debug=args.debug) 176 | 177 | # Set the random seed manually for reproducibility. 
178 | np.random.seed(args.seed) 179 | torch.manual_seed(args.seed) 180 | if torch.cuda.is_available(): 181 | if not args.cuda: 182 | print('WARNING: You have a CUDA device, so you should probably run with --cuda') 183 | else: 184 | torch.cuda.manual_seed_all(args.seed) 185 | 186 | # Validate `--fp16` option 187 | if args.fp16: 188 | if not args.cuda: 189 | print('WARNING: --fp16 requires --cuda, ignoring --fp16 option') 190 | args.fp16 = False 191 | else: 192 | try: 193 | from apex.fp16_utils import FP16_Optimizer 194 | except: 195 | print('WARNING: apex not installed, ignoring --fp16 option') 196 | args.fp16 = False 197 | 198 | device = torch.device('cuda' if args.cuda else 'cpu') 199 | 200 | ############################################################################### 201 | # Load data 202 | ############################################################################### 203 | corpus = get_lm_corpus(args.data, args.dataset) 204 | ntokens = len(corpus.vocab) 205 | args.n_token = ntokens 206 | 207 | eval_batch_size = 10 208 | tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len, 209 | device=device, ext_len=args.ext_len, future_len=args.future_len) 210 | va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, 211 | device=device, ext_len=args.ext_len, future_len=args.future_len) 212 | te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, 213 | device=device, ext_len=args.ext_len, future_len=args.future_len) 214 | 215 | # adaptive softmax / embedding 216 | cutoffs, tie_projs = [], [False] 217 | if args.adaptive: 218 | assert args.dataset in ['wt103', 'lm1b'] 219 | if args.dataset == 'wt103': 220 | cutoffs = [20000, 40000, 200000] 221 | tie_projs += [True] * len(cutoffs) 222 | elif args.dataset == 'lm1b': 223 | cutoffs = [60000, 100000, 640000] 224 | tie_projs += [False] * len(cutoffs) 225 | 226 | ############################################################################### 227 | # Build the model 228 | ############################################################################### 229 | def init_weight(weight): 230 | if args.init == 'uniform': 231 | nn.init.uniform_(weight, -args.init_range, args.init_range) 232 | elif args.init == 'normal': 233 | nn.init.normal_(weight, 0.0, args.init_std) 234 | 235 | def init_bias(bias): 236 | nn.init.constant_(bias, 0.0) 237 | 238 | def weights_init(m): 239 | classname = m.__class__.__name__ 240 | if classname.find('Linear') != -1: 241 | if hasattr(m, 'weight') and m.weight is not None: 242 | init_weight(m.weight) 243 | if hasattr(m, 'bias') and m.bias is not None: 244 | init_bias(m.bias) 245 | elif classname.find('AdaptiveEmbedding') != -1: 246 | if hasattr(m, 'emb_projs'): 247 | for i in range(len(m.emb_projs)): 248 | if m.emb_projs[i] is not None: 249 | nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std) 250 | elif classname.find('Embedding') != -1: 251 | if hasattr(m, 'weight'): 252 | init_weight(m.weight) 253 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: 254 | if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: 255 | init_weight(m.cluster_weight) 256 | if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: 257 | init_bias(m.cluster_bias) 258 | if hasattr(m, 'out_projs'): 259 | for i in range(len(m.out_projs)): 260 | if m.out_projs[i] is not None: 261 | nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std) 262 | elif classname.find('LayerNorm') != -1: 263 | if hasattr(m, 'weight'): 264 | nn.init.normal_(m.weight, 1.0, args.init_std) 265 | if 
hasattr(m, 'bias') and m.bias is not None: 266 | init_bias(m.bias) 267 | elif classname.find('TransformerLM') != -1: 268 | if hasattr(m, 'r_emb'): 269 | init_weight(m.r_emb) 270 | if hasattr(m, 'r_w_bias'): 271 | init_weight(m.r_w_bias) 272 | if hasattr(m, 'r_r_bias'): 273 | init_weight(m.r_r_bias) 274 | if hasattr(m, 'r_bias'): 275 | init_bias(m.r_bias) 276 | 277 | def repackage_hidden(h): 278 | """Wraps hidden states in new Tensors, to detach them from their history.""" 279 | if isinstance(h, torch.Tensor): 280 | return h.detach() 281 | else: 282 | return tuple(repackage_hidden(v) for v in h) 283 | 284 | def update_dropout(m): 285 | classname = m.__class__.__name__ 286 | if classname.find('Dropout') != -1: 287 | if hasattr(m, 'p'): 288 | m.p = args.dropout 289 | 290 | def update_dropatt(m): 291 | if hasattr(m, 'dropatt'): 292 | m.dropatt.p = args.dropatt 293 | 294 | if args.restart: 295 | with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f: 296 | model = torch.load(f) 297 | if not args.fp16: 298 | model = model.float() 299 | model.apply(update_dropout) 300 | model.apply(update_dropatt) 301 | else: 302 | model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model, 303 | args.d_head, args.d_inner, args.dropout, args.dropatt, 304 | tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val, 305 | tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, 306 | ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs, 307 | same_length=args.same_length, attn_type=args.attn_type, 308 | clamp_len=args.clamp_len, sample_softmax=args.sample_softmax, 309 | rnnenc=args.rnnenc, rnndim=args.rnndim, 310 | layer_list=args.layerlist, future_len=args.future_len, 311 | attn_layerlist=args.pen_layerlist, merge_type=args.merge_type) 312 | model.apply(weights_init) 313 | model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing 314 | args.n_all_param = sum([p.nelement() for p in model.parameters()]) 315 | args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()]) 316 | 317 | # use rnn for connection 318 | if args.rnnenc and args.rnnmodel != "": 319 | rnnmodel_dict = torch.load(args.rnnmodel).rnn.state_dict() 320 | for i, layer in enumerate(model.rnn_list): 321 | model_dict = layer.state_dict() 322 | update_dict = {k: v for k, v in rnnmodel_dict.items() if k in model_dict} 323 | model_dict.update(update_dict) 324 | model.rnn_list[i].load_state_dict(model_dict) 325 | 326 | if args.fp16: 327 | model = model.half() 328 | 329 | if args.multi_gpu: 330 | model = model.to(device) 331 | if args.gpu0_bsz >= 0: 332 | para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk, 333 | model, dim=1).to(device) 334 | else: 335 | para_model = nn.DataParallel(model, dim=1).to(device) 336 | else: 337 | para_model = model.to(device) 338 | 339 | #### optimizer 340 | if args.optim.lower() == 'sgd': 341 | if args.sample_softmax > 0: 342 | dense_params, sparse_params = [], [] 343 | for param in model.parameters(): 344 | if param.size() == model.word_emb.weight.size(): 345 | sparse_params.append(param) 346 | else: 347 | dense_params.append(param) 348 | optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2) 349 | optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom) 350 | else: 351 | optimizer = optim.SGD(model.parameters(), lr=args.lr, 352 | momentum=args.mom) 353 | elif args.optim.lower() == 'adam': 354 | if args.sample_softmax > 0: 355 | dense_params, sparse_params = [], [] 356 | for 
param in model.parameters(): 357 | if param.size() == model.word_emb.weight.size(): 358 | sparse_params.append(param) 359 | else: 360 | dense_params.append(param) 361 | optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr) 362 | optimizer = optim.Adam(dense_params, lr=args.lr) 363 | else: 364 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 365 | elif args.optim.lower() == 'adagrad': 366 | optimizer = optim.Adagrad(model.parameters(), lr=args.lr) 367 | 368 | #### scheduler 369 | if args.scheduler == 'cosine': 370 | # here we do not set eta_min to lr_min to be backward compatible 371 | # because in previous versions eta_min is default to 0 372 | # rather than the default value of lr_min 1e-6 373 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 374 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 375 | if args.sample_softmax > 0: 376 | scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse, 377 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 378 | elif args.scheduler == 'inv_sqrt': 379 | # originally used for Transformer (in Attention is all you need) 380 | def lr_lambda(step): 381 | # return a multiplier instead of a learning rate 382 | if step == 0 and args.warmup_step == 0: 383 | return 1. 384 | else: 385 | return 1. / (step ** 0.5) if step > args.warmup_step \ 386 | else step / (args.warmup_step ** 1.5) 387 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 388 | elif args.scheduler == 'dev_perf': 389 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 390 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 391 | if args.sample_softmax > 0: 392 | scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse, 393 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 394 | elif args.scheduler == 'constant': 395 | pass 396 | 397 | if args.cuda and args.fp16: 398 | # If args.dynamic_loss_scale is False, static_loss_scale will be used. 399 | # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale. 400 | optimizer = FP16_Optimizer(optimizer, 401 | static_loss_scale = args.static_loss_scale, 402 | dynamic_loss_scale = args.dynamic_loss_scale, 403 | dynamic_loss_args = {'init_scale': 2 ** 16}) 404 | 405 | if args.restart: 406 | if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')): 407 | with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f: 408 | opt_state_dict = torch.load(f) 409 | optimizer.load_state_dict(opt_state_dict) 410 | else: 411 | print('Optimizer was not saved. Start from scratch.') 412 | 413 | logging('=' * 100) 414 | for k, v in args.__dict__.items(): 415 | logging(' - {} : {}'.format(k, v)) 416 | logging('=' * 100) 417 | logging('#params = {}'.format(args.n_all_param)) 418 | logging('#non emb params = {}'.format(args.n_nonemb_param)) 419 | 420 | ############################################################################### 421 | # Training code 422 | ############################################################################### 423 | 424 | def evaluate(eval_iter): 425 | # Turn on evaluation mode which disables dropout. 426 | model.eval() 427 | 428 | # If the model does not use memory at all, make the ext_len longer. 429 | # Otherwise, make the mem_len longer and keep the ext_len the same. 
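    # A worked example of the reset_length calls below (illustrative numbers, not taken from the run scripts):
    # with tgt_len=32, ext_len=0 and eval_tgt_len=16:
    #   mem_len=32 -> model.reset_length(16, 0, 32 + 32 - 16) == reset_length(16, 0, 48)
    #   mem_len=0  -> model.reset_length(16, 0 + 32 - 16, 0)  == reset_length(16, 16, 0)
    # Either way the total span tgt_len + ext_len + mem_len available to the model is unchanged at evaluation time.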
430 | if args.mem_len == 0: 431 | model.reset_length(args.eval_tgt_len, 432 | args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len) 433 | else: 434 | model.reset_length(args.eval_tgt_len, 435 | args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len) 436 | 437 | # Evaluation 438 | total_len, total_loss = 0, 0. 439 | # add RNN connection 440 | rnn_hidden = None 441 | if args.rnnenc: 442 | rnn_hidden = para_model.init_hidden(eval_batch_size) 443 | 444 | with torch.no_grad(): 445 | mems = tuple() 446 | for i, (data, target, seq_len, future_seqlen) in enumerate(eval_iter): 447 | if args.max_eval_steps > 0 and i >= args.max_eval_steps: 448 | break 449 | ret = model(data, target, *mems, rnn_hidden=rnn_hidden, future_seqlen=future_seqlen) 450 | loss, mems, penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1] 451 | loss = loss.mean() 452 | total_loss += seq_len * loss.float().item() 453 | total_len += seq_len 454 | 455 | # Switch back to the training mode 456 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 457 | model.train() 458 | 459 | return total_loss / total_len 460 | 461 | 462 | def train(): 463 | # Turn on training mode which enables dropout. 464 | global train_step, train_loss, best_val_loss, eval_start_time, log_start_time, train_p 465 | model.train() 466 | if args.batch_chunk > 1: 467 | mems = [tuple() for _ in range(args.batch_chunk)] 468 | else: 469 | mems = tuple() 470 | train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter 471 | # add RNN connection 472 | rnn_hidden = None 473 | if args.rnnenc: 474 | rnn_hidden = model.init_hidden(args.batch_size) 475 | prev_data = None 476 | for batch, (data, target, seq_len, future_seqlen) in enumerate(train_iter): 477 | model.zero_grad() 478 | # use RNN state carry on 479 | if rnn_hidden is not None: 480 | rnn_hidden = [repackage_hidden(hid) for hid in rnn_hidden] 481 | if args.batch_chunk > 1: 482 | data_chunks = torch.chunk(data, args.batch_chunk, 1) 483 | target_chunks = torch.chunk(target, args.batch_chunk, 1) 484 | for i in range(args.batch_chunk): 485 | data_i = data_chunks[i].contiguous() 486 | target_i = target_chunks[i].contiguous() 487 | ret = para_model(data_i, target_i, *mems[i], rnn_hidden=rnn_hidden) 488 | loss, mems[i], penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1] 489 | loss = loss.float().mean().type_as(loss) / args.batch_chunk 490 | if args.fp16: 491 | optimizer.backward(loss) 492 | else: 493 | loss.backward() 494 | train_loss += loss.float().item() 495 | else: 496 | ret = para_model(data, target, *mems, rnn_hidden=rnn_hidden, future_seqlen=future_seqlen) 497 | prev_data = data 498 | loss, mems, penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1] 499 | loss = loss.float().mean().type_as(loss) 500 | # Add penalty 501 | p_loss = loss + penalty * args.p_scale 502 | if args.fp16: 503 | optimizer.backward(p_loss) 504 | else: 505 | p_loss.backward() 506 | train_loss += loss.float().item() 507 | train_p += penalty.float().item() 508 | 509 | if args.fp16: 510 | optimizer.clip_master_grads(args.clip) 511 | else: 512 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 513 | 514 | optimizer.step() 515 | if args.sample_softmax > 0: 516 | optimizer_sparse.step() 517 | 518 | # step-wise learning rate annealing 519 | train_step += 1 520 | if args.scheduler in ['cosine', 'constant', 'dev_perf']: 521 | # linear warmup stage 522 | if train_step < args.warmup_step: 523 | curr_lr = args.lr * train_step / args.warmup_step 524 | optimizer.param_groups[0]['lr'] = curr_lr 525 | if 
args.sample_softmax > 0: 526 | optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2 527 | else: 528 | if args.scheduler == 'cosine': 529 | scheduler.step(train_step) 530 | if args.sample_softmax > 0: 531 | scheduler_sparse.step(train_step) 532 | elif args.scheduler == 'inv_sqrt': 533 | scheduler.step(train_step) 534 | 535 | if train_step % args.log_interval == 0: 536 | cur_loss = train_loss / args.log_interval 537 | cur_p = train_p / args.log_interval 538 | elapsed = time.time() - log_start_time 539 | log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \ 540 | '| ms/batch {:5.2f} | loss {:5.2f}'.format( 541 | epoch, train_step, batch+1, optimizer.param_groups[0]['lr'], 542 | elapsed * 1000 / args.log_interval, cur_loss) 543 | if args.dataset in ['enwik8', 'text8']: 544 | log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2)) 545 | else: 546 | log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss)) 547 | logging(log_str) 548 | train_loss = 0 549 | train_p = 0 550 | log_start_time = time.time() 551 | 552 | if train_step % args.eval_interval == 0: 553 | val_loss = evaluate(va_iter) 554 | logging('-' * 100) 555 | log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ 556 | '| valid loss {:5.2f}'.format( 557 | train_step // args.eval_interval, train_step, 558 | (time.time() - eval_start_time), val_loss) 559 | if args.dataset in ['enwik8', 'text8']: 560 | log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2)) 561 | else: 562 | log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss)) 563 | logging(log_str) 564 | logging('-' * 100) 565 | # Save the model if the validation loss is the best we've seen so far. 566 | if not best_val_loss or val_loss < best_val_loss: 567 | if not args.debug: 568 | with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f: 569 | torch.save(model, f) 570 | with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f: 571 | torch.save(optimizer.state_dict(), f) 572 | best_val_loss = val_loss 573 | 574 | # dev-performance based learning rate annealing 575 | if args.scheduler == 'dev_perf': 576 | scheduler.step(val_loss) 577 | if args.sample_softmax > 0: 578 | scheduler_sparse.step(val_loss) 579 | 580 | eval_start_time = time.time() 581 | 582 | if train_step == args.max_step: 583 | break 584 | 585 | # Loop over epochs. 586 | train_step = 0 587 | train_loss = 0 588 | train_p = 0 589 | best_val_loss = None 590 | 591 | log_start_time = time.time() 592 | eval_start_time = time.time() 593 | 594 | # At any point you can hit Ctrl + C to break out of training early. 595 | if not args.evalmode: 596 | try: 597 | for epoch in itertools.count(start=1): 598 | train() 599 | if train_step == args.max_step: 600 | logging('-' * 100) 601 | logging('End of training') 602 | break 603 | except KeyboardInterrupt: 604 | logging('-' * 100) 605 | logging('Exiting from training early') 606 | 607 | # Load the best saved model. 608 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: 609 | model = torch.load(f) 610 | para_model = model.to(device) 611 | 612 | # Run on test data. 
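# Reference for the numbers logged below (hypothetical loss value, not a measured result):
# test_loss is the average per-token cross-entropy in nats, so a test loss of 4.00 corresponds to
# ppl = exp(4.00) ~ 54.6, while the character-level sets (enwik8/text8) report bpc = test_loss / ln(2).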
613 | test_loss = evaluate(te_iter) 614 | logging('=' * 100) 615 | if args.dataset in ['enwik8', 'text8']: 616 | logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format( 617 | test_loss, test_loss / math.log(2))) 618 | else: 619 | logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format( 620 | test_loss, math.exp(test_loss))) 621 | logging('=' * 100) 622 | -------------------------------------------------------------------------------- /rescore.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import sys, os 4 | import torch 5 | import math 6 | import time 7 | from operator import itemgetter 8 | import numpy as np 9 | import torch.nn.functional as F 10 | 11 | # import data 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Level-2 RNN/LSTM Language Model') 14 | parser.add_argument('--data', type=str, default='./data/AMI', 15 | help='location of the data corpus') 16 | parser.add_argument('--model', type=str, default='model.pt', 17 | help='location of the 1st level model') 18 | parser.add_argument('--cuda', action='store_true', 19 | help='use CUDA') 20 | parser.add_argument('--lookback', type=int, default=0, 21 | help='Number of backward utterance embeddings to be incorporated') 22 | parser.add_argument('--uttlookforward', type=int, default=1, 23 | help='Number of forward utterance embeddings to be incorporated') 24 | parser.add_argument('--excludeself', type=int, default=1, 25 | help='current utterance embeddings to be incorporated') 26 | parser.add_argument('--saveprefix', type=str, default='tensors/AMI', 27 | help='Specify which data utterance embeddings saved') 28 | parser.add_argument('--nbest', type=str, default='dev.nbest.info.txt', 29 | help='Specify which nbest file to be used') 30 | parser.add_argument('--lmscale', type=float, default=6, 31 | help='how much importance to attach to rnn score') 32 | parser.add_argument('--lm', type=str, default='original', 33 | help='Specify which language model to be used: rnn, ngram or original') 34 | parser.add_argument('--ngramlist', type=str, default='', 35 | help='Specify which ngram stream file to be used') 36 | parser.add_argument('--saveemb', action='store_true', 37 | help='save utterance embeddings') 38 | parser.add_argument('--save1best', action='store_true', 39 | help='save 1best list') 40 | parser.add_argument('--context', type=str, default='0', 41 | help='Specify which utterance embeddings to be used') 42 | parser.add_argument('--use_context', action='store_true', 43 | help='use future context') 44 | parser.add_argument('--contextfile', type=str, default='rescore/time_sorted_dev.nbestlist.context', 45 | help='1best file for context') 46 | parser.add_argument('--logfile', type=str, default='LOGs/log.txt', 47 | help='the logfile for this script') 48 | parser.add_argument('--interp', action='store_true', 49 | help='Linear interpolation of LMs') 50 | parser.add_argument('--factor', type=float, default=0.8, 51 | help='ngram interpolation weight factor') 52 | parser.add_argument('--gscale', type=float, default=12.0, 53 | help='ngram grammar scaling factor') 54 | parser.add_argument('--maxlen', type=int, default=0, 55 | help='how many future words to look at') 56 | parser.add_argument('--use_true', action='store_true', 57 | help='Use true dev file for study') 58 | parser.add_argument('--true_file', type=str, default='data/dev.txt', 59 | help='Specify which true context file to be used') 60 | 
parser.add_argument('--futurescale', type=float, default=1.0, 61 | help='how much importance to attach to future word scores') 62 | parser.add_argument('--map', type=str, default='rescore/eval.map', 63 | help='mapping file for utterance names') 64 | parser.add_argument('--subbatchsize', type=int, default=20, 65 | help='Sub batch size for batched forwarding') 66 | parser.add_argument('--mem_len', type=int, default=0, 67 | help='Sub batch size for batched forwarding') 68 | parser.add_argument('--ppl', action='store_true', 69 | help='Calculate and report ppl') 70 | parser.add_argument('--pplword', action='store_true', 71 | help='Calculate and report average ppl for each word') 72 | parser.add_argument('--extra-model', type=str, default='', 73 | help='Extra LM to be interpolated') 74 | parser.add_argument('--extra-modeltype', type=str, default='RNN', 75 | help='The type of the extra LM to be interpolated') 76 | 77 | args = parser.parse_args() 78 | 79 | def logging(s, print_=True, log_=True): 80 | if print_: 81 | print(s) 82 | if log_: 83 | with open(args.logfile, 'a+') as f_log: 84 | f_log.write(s + '\n') 85 | 86 | # Read in dictionary 87 | logging("Reading dictionary...") 88 | dictionary = {} 89 | with open(os.path.join(args.data, 'dictionary.txt')) as vocabin: 90 | lines = vocabin.readlines() 91 | for line in lines: 92 | ind, word = line.strip().split(' ') 93 | if word not in dictionary: 94 | dictionary[word] = ind 95 | else: 96 | logging("Error! Repeated words in the dictionary!") 97 | 98 | ntokens = len(dictionary) 99 | eosidx = int(dictionary['']) 100 | 101 | if args.pplword: 102 | wordppl = {} 103 | for word, ind in dictionary.items(): 104 | wordppl[word] = [0.0, 0] 105 | 106 | device = torch.device("cuda" if args.cuda else "cpu") 107 | cpu = torch.device("cpu") 108 | 109 | # Read in trained 1st level model 110 | logging("Reading model...") 111 | with open(args.model, 'rb') as f: 112 | model = torch.load(f) 113 | # after load the rnn params are not a continuous chunk of memory 114 | # this makes them a continuous chunk, and will speed up forward pass 115 | if args.cuda: 116 | model.cuda() 117 | 118 | logging("Reading extra model...") 119 | extramodel = None 120 | if args.extra_model != '': 121 | with open(args.extra_model, 'rb') as f: 122 | extramodel = torch.load(f) 123 | # after load the rnn params are not a continuous chunk of memory 124 | # this makes them a continuous chunk, and will speed up forward pass 125 | if args.cuda: 126 | extramodel.cuda() 127 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 128 | 129 | def repackage_hidden(h): 130 | """Wraps hidden states in new Tensors, to detach them from their history.""" 131 | if isinstance(h, torch.Tensor): 132 | return h.detach() 133 | else: 134 | return tuple(repackage_hidden(v) for v in h) 135 | 136 | def forward_extra(extra_model, inputs, targets, hidden): 137 | prob_list = [] 138 | hidden_list = [] 139 | for i, input_data in enumerate(inputs): 140 | target = targets[i] 141 | output, new_hidden = extra_model(input_data.view(-1, 1).to(device), hidden) 142 | probs = criterion(output.squeeze(1), target.to(device)) 143 | probs = torch.exp(-probs) 144 | prob_list.append(probs) 145 | hidden_list.append(new_hidden) 146 | return prob_list, hidden_list 147 | 148 | # Batched forward lookback RNN 149 | def forward_each_utt_batched_lookback_rnn(model, 150 | lines, 151 | utt_name, 152 | prev_utts, 153 | mems, 154 | ppl=False, 155 | hidden=None, 156 | extra_model=None, 157 | extra_hidden=None): 158 | # Process each line 159 | inputs = 
[] 160 | targets = [] 161 | ac_scores = [] 162 | lm_scores = [] 163 | maxlen = 0 164 | # new_mems = mems 165 | utterances = [] 166 | utterances_ind = [] 167 | extra_inputs = [] 168 | extra_targets = [] 169 | target_index_list = [] 170 | 171 | for line in lines: 172 | linevec = line.strip().split() 173 | ac_score = float(linevec[0]) 174 | utterance = linevec[4:-1] 175 | currentline = [] 176 | for i, word in enumerate(utterance): 177 | if word in dictionary: 178 | currentline.append(int(dictionary[word])) 179 | else: 180 | currentline.append(int(dictionary[''])) 181 | utterances.append(utterance) 182 | utterances_ind.append(currentline) 183 | ac_scores.append(ac_score) 184 | if len(currentline) > maxlen: 185 | maxlen = len(currentline) 186 | mask = [] 187 | ac_score_tensor = torch.tensor(ac_scores).to(device) 188 | # Pad inputs and targets, prev_append in [0, len(prev_utts)] 189 | prev_append = max(min(args.lookback - maxlen, len(prev_utts)), 1) 190 | for i, symbols in enumerate(utterances_ind): 191 | full_sequence = prev_utts[-prev_append:] + symbols + [eosidx] * (maxlen - len(symbols) + 1) 192 | inputs.append(full_sequence[:-1]) 193 | targets.append(full_sequence[1:]) 194 | # get interpolated model inputs and targets 195 | extra_inputs.append(torch.LongTensor([eosidx] + symbols)) 196 | extra_targets.append(torch.LongTensor(symbols+[eosidx])) 197 | mask.append([0.0] * (prev_append - 1) + [1.0] * (len(symbols) + 1) + [0.0] * (maxlen - len(symbols))) 198 | # arrange inputs and targets into tensors 199 | input_tensor = torch.LongTensor(inputs).to(device).t().contiguous() 200 | target_tensor = torch.LongTensor(targets).to(device).t().contiguous() 201 | mask_tensor = torch.tensor(mask).to(device).t().contiguous() 202 | bsize = input_tensor.size(1) 203 | seq_len = input_tensor.size(0) 204 | 205 | # forward prop interpolate model 206 | if args.extra_model != '' and args.extra_modeltype == 'RNN': 207 | interp_prob_list, extra_hidden_list = forward_extra(extra_model, extra_inputs, extra_targets, extra_hidden) 208 | 209 | # Forward prop transformer 210 | logProblist = [] 211 | mem_list = [] 212 | ppl_list = [] 213 | hidden_list = [] 214 | # initialise RNN hidden state 215 | if hidden is None and getattr(model, "rnnenc", False): 216 | hidden = model.init_hidden(1) 217 | pos_hidden = [(hid[0][-1:], hid[1][-1:]) for hid in hidden] 218 | elif prev_append == 1 and getattr(model, "rnnenc", False): 219 | pos_hidden = [(hid[0][-1:], hid[1][-1:]) for hid in hidden] 220 | elif getattr(model, "rnnenc", False): 221 | pos_hidden = [(hid[0][-prev_append:-prev_append+1], hid[1][-prev_append:-prev_append+1]) for hid in hidden] 222 | # import pdb; pdb.set_trace() 223 | # transformer XL 224 | tiled_mems = tuple() 225 | if len(mems) > 0 and prev_append < len(prev_utts): 226 | # determine how much memory to keep: prev_append + tiled_mems[0].size(0) = mems[0].size(0) 227 | tiled_mems = [mem[-prev_append-args.mem_len+1:-prev_append+1].repeat(1, bsize, 1) for mem in mems] 228 | # Start forwarding 229 | for i in range(0, bsize, args.subbatchsize): 230 | # mems for transformer XL 231 | if len(tiled_mems) > 0: 232 | this_mem = [mem[:, i:i+args.subbatchsize, :].contiguous() for mem in tiled_mems] 233 | else: 234 | this_mem = tuple() 235 | 236 | bsz = min(args.subbatchsize, bsize - i) 237 | # expand rnn hidden state 238 | rnn_hidden = None 239 | if hidden is not None: 240 | rnn_hidden = [(hid[0].repeat(1, bsz, 1), hid[1].repeat(1, bsz, 1)) for hid in pos_hidden] 241 | 242 | ret = model(input_tensor[:, 
i:i+args.subbatchsize].contiguous(), 243 | target_tensor[:, i:i+args.subbatchsize].contiguous(), 244 | *this_mem, rnn_hidden=rnn_hidden, stepwise=True) 245 | loss, this_mem, penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1] 246 | 247 | if args.mem_len > 0 and len(this_mem) > 0: 248 | mem_list.append(torch.stack(this_mem)) 249 | loss = loss * mask_tensor[:, i:i+args.subbatchsize] 250 | logProblist.append(loss) 251 | hidden_list.append(rnn_hidden) 252 | if args.pplword: 253 | ppl_list.append(loss) 254 | # outputlist.append(output[:,-1,:]) 255 | lmscores = torch.cat(logProblist, 1) 256 | if args.extra_model == '': 257 | lmscores = torch.sum(lmscores, dim=0) 258 | else: 259 | interpolated_score = [] 260 | for i, probs in enumerate(interp_prob_list): 261 | tranformer_score = lmscores[:,i].tolist() 262 | tranformer_score = torch.tensor([np.exp(-score) for score in tranformer_score if score > 0]).to(device) 263 | assert len(probs) == len(tranformer_score) 264 | lmscore = -torch.log(args.factor * tranformer_score + (1 - args.factor) * probs) 265 | interpolated_score.append(torch.sum(lmscore)) 266 | lmscores = torch.stack(interpolated_score) 267 | # lmscores = torch.sum(logProb.view(seq_len, bsize)*mask_tensor, 0) 268 | total_scores = - lmscores * args.lmscale + ac_score_tensor 269 | # Get output in some format 270 | outputlines = [] 271 | for i, utt in enumerate(utterances): 272 | out = ' '.join([utt_name+'-'+str(i+1), '{:5.2f}'.format(lmscores[i])]) 273 | outputlines.append(out+'\n') 274 | max_ind = torch.argmax(total_scores) 275 | # RNN hidden state selection 276 | if len(hidden_list) > 0: 277 | all_hidden = [] 278 | for hid in zip(*hidden_list): 279 | hid_l = list(zip(*hid)) 280 | all_hidden.append((torch.cat(hid_l[0], dim=1), torch.cat(hid_l[1], dim=1))) 281 | best_hid = [(hid[0][:, max_ind:max_ind+1, :], hid[1][:, max_ind:max_ind+1, :]) for hid in all_hidden] 282 | else: 283 | best_hid = None 284 | 285 | best_utt = utterances[max_ind] 286 | best_utt_len = len(utterances[max_ind]) 287 | prev_utts += (utterances_ind[max_ind] + [eosidx]) 288 | hidden = [(torch.cat([hidden[i][0][-args.lookback-2:], hid[0][:best_utt_len+1]], dim=0), 289 | torch.cat([hidden[i][1][-args.lookback-2:], hid[1][:best_utt_len+1]], dim=0)) for i, hid in enumerate(best_hid)] 290 | if len(mem_list) > 0: 291 | mem_list = torch.cat(mem_list, dim=2) 292 | start_pos = max(mem_list.size(1) - maxlen - 1, 0) 293 | mem_list = mem_list[:, start_pos:start_pos+len(best_utt)+1, max_ind:max_ind+1, :] 294 | if len(mems) > 0: 295 | mem_list = [torch.cat([mems[i], mem_list[i]])[-(args.mem_len+args.lookback):] for i in range(mem_list.size(0))] 296 | else: 297 | mem_list = [mem_list[i] for i in range(mem_list.size(0))] 298 | # extra hidden states for interpolation 299 | extrahidden = extra_hidden_list[max_ind] if args.extra_model != '' else None 300 | # calculate perplexity 301 | best_ppl = lmscores[max_ind] if ppl else None 302 | # calculate per word perplexity 303 | if args.pplword: 304 | ppl_list = torch.cat(ppl_list, dim=1)[:, max_ind] 305 | for i, word in enumerate(best_utt+['']): 306 | if word in wordppl: 307 | wordppl[word][0] += ppl_list[i+prev_append-1] 308 | wordppl[word][1] += 1 309 | else: 310 | wordppl['OOV'][0] += ppl_list[i+prev_append-1] 311 | wordppl['OOV'][1] += 1 312 | return best_utt, outputlines, prev_utts[-args.lookback:], mem_list, best_ppl, hidden, extrahidden 313 | 314 | # Batched forward lookback 315 | def forward_each_utt_batched_lookback(model, 316 | lines, 317 | utt_name, 318 | prev_utts, 319 | mems, 320 
| ppl=False, 321 | hidden=None, 322 | extra_model=None, 323 | extra_hidden=None): 324 | # Process each line 325 | inputs = [] 326 | targets = [] 327 | ac_scores = [] 328 | lm_scores = [] 329 | maxlen = 0 330 | new_mems = mems 331 | utterances = [] 332 | utterances_ind = [] 333 | target_index_list = [] 334 | extra_inputs = [] 335 | extra_targets = [] 336 | 337 | for line in lines: 338 | linevec = line.strip().split() 339 | ac_score = float(linevec[0]) 340 | utterance = linevec[4:-1] 341 | currentline = [] 342 | for i, word in enumerate(utterance): 343 | if word in dictionary: 344 | currentline.append(int(dictionary[word])) 345 | else: 346 | currentline.append(int(dictionary[''])) 347 | utterances.append(utterance) 348 | utterances_ind.append(currentline) 349 | ac_scores.append(ac_score) 350 | if len(currentline) > maxlen: 351 | maxlen = len(currentline) 352 | mask = [] 353 | ac_score_tensor = torch.tensor(ac_scores).to(device) 354 | # Pad inputs and targets, prev_append in [0, len(prev_utts)] 355 | prev_append = max(min(args.lookback - maxlen, len(prev_utts)), 1) 356 | for i, symbols in enumerate(utterances_ind): 357 | full_sequence = prev_utts[-prev_append:] + symbols + [eosidx] * (maxlen - len(symbols) + 1) 358 | inputs.append(full_sequence[:-1]) 359 | targets.append(full_sequence[1:]) 360 | # get interpolated model inputs and targets 361 | extra_inputs.append(torch.LongTensor([eosidx] + symbols)) 362 | extra_targets.append(torch.LongTensor(symbols+[eosidx])) 363 | 364 | mask.append([0.0] * (prev_append - 1) + [1.0] * (len(symbols) + 1) + [0.0] * (maxlen - len(symbols))) 365 | # arrange inputs and targets into tensors 366 | input_tensor = torch.LongTensor(inputs).to(device).t().contiguous() 367 | target_tensor = torch.LongTensor(targets).to(device).t().contiguous() 368 | mask_tensor = torch.tensor(mask).to(device).t().contiguous() 369 | bsize = input_tensor.size(1) 370 | seq_len = input_tensor.size(0) 371 | 372 | # forward prop interpolate model 373 | if args.extra_model != '' and args.extra_modeltype == 'RNN': 374 | interp_prob_list, extra_hidden_list = forward_extra(extra_model, extra_inputs, extra_targets, extra_hidden) 375 | 376 | # Forward prop transformer 377 | logProblist = [] 378 | mem_list = [] 379 | ppl_list = [] 380 | # initialise RNN hidden stte 381 | if hidden is None and getattr(model, "rnnenc", False): 382 | hidden = model.init_hidden(1) 383 | # transformer XL 384 | tiled_mems = tuple() 385 | if len(mems) > 0 and prev_append < len(prev_utts): 386 | # determine how much memory to keep: prev_append + tiled_mems[0].size(0) = mems[0].size(0) 387 | tiled_mems = [mem[-prev_append-args.mem_len+1:-prev_append+1].repeat(1, bsize, 1) for mem in mems] 388 | for i in range(0, bsize, args.subbatchsize): 389 | # mems for transformer XL 390 | if len(tiled_mems) > 0: 391 | this_mem = [mem[:, i:i+args.subbatchsize, :].contiguous() for mem in tiled_mems] 392 | else: 393 | this_mem = tuple() 394 | 395 | bsz = min(args.subbatchsize, bsize - i) 396 | # expand rnn hidden state 397 | rnn_hidden = None 398 | if hidden is not None: 399 | rnn_hidden = [(hid[0].repeat(1, bsz, 1), hid[1].repeat(1, bsz, 1)) for hid in hidden] 400 | 401 | ret = model(input_tensor[:, i:i+args.subbatchsize].contiguous(), 402 | target_tensor[:, i:i+args.subbatchsize].contiguous(), 403 | *this_mem, rnn_hidden=rnn_hidden) 404 | loss, this_mem, penalty, hidden = ret[0], ret[1:-2], ret[-2], ret[-1] 405 | if args.mem_len > 0 and len(this_mem) > 0: 406 | mem_list.append(torch.stack(this_mem)) 407 | loss = loss * mask_tensor[:, 
i:i+args.subbatchsize] 408 | logProblist.append(loss) 409 | if args.pplword: 410 | ppl_list.append(loss) 411 | # outputlist.append(output[:,-1,:]) 412 | lmscores = torch.cat(logProblist, 1) 413 | if args.extra_model == '': 414 | lmscores = torch.sum(lmscores, dim=0) 415 | else: 416 | interpolated_score = [] 417 | for i, probs in enumerate(interp_prob_list): 418 | tranformer_score = lmscores[:,i].tolist() 419 | tranformer_score = torch.tensor([np.exp(-score) for score in tranformer_score if score > 0]).to(device) 420 | assert len(probs) == len(tranformer_score) 421 | lmscore = -torch.log(args.factor * tranformer_score + (1 - args.factor) * probs) 422 | interpolated_score.append(torch.sum(lmscore)) 423 | lmscores = torch.stack(interpolated_score) 424 | # lmscores = torch.sum(logProb.view(seq_len, bsize)*mask_tensor, 0) 425 | total_scores = - lmscores * args.lmscale + ac_score_tensor 426 | # Get output in some format 427 | outputlines = [] 428 | for i, utt in enumerate(utterances): 429 | out = ' '.join([utt_name+'-'+str(i+1), '{:5.2f}'.format(lmscores[i])]) 430 | outputlines.append(out+'\n') 431 | max_ind = torch.argmax(total_scores) 432 | best_utt = utterances[max_ind] 433 | prev_utts += (utterances_ind[max_ind] + [eosidx]) 434 | if len(mem_list) > 0: 435 | new_mems = torch.cat(mem_list, dim=2) 436 | start_pos = max(new_mems.size(1) - maxlen - 1, 0) 437 | new_mems = new_mems[:, start_pos:start_pos+len(best_utt)+1, max_ind:max_ind+1, :] 438 | if len(mems) > 0: 439 | new_mems = [torch.cat([mems[i], new_mems[i]])[-(args.mem_len+args.lookback):] for i in range(new_mems.size(0))] 440 | else: 441 | new_mems = [new_mems[i] for i in range(new_mems.size(0))] 442 | # extra hidden states for interpolation 443 | extrahidden = extra_hidden_list[max_ind] if args.extra_model != '' else None 444 | # calculate perplexity 445 | best_ppl = lmscores[max_ind] if ppl else None 446 | # calculate per word perplexity 447 | if args.pplword: 448 | ppl_list = torch.cat(ppl_list, dim=1)[:, max_ind] 449 | for i, word in enumerate(best_utt+['']): 450 | if word in wordppl: 451 | wordppl[word][0] += ppl_list[i+prev_append-1] 452 | wordppl[word][1] += 1 453 | else: 454 | wordppl['OOV'][0] += ppl_list[i+prev_append-1] 455 | wordppl['OOV'][1] += 1 456 | return best_utt, outputlines, prev_utts[-args.lookback:], new_mems, best_ppl, extrahidden 457 | 458 | # Batched forward 459 | def forward_each_utt_batched(model, lines, utt_name, prev_utts, mems, ppl=False, hidden=None): 460 | # Process each line 461 | inputs = [] 462 | targets = [] 463 | ac_scores = [] 464 | lm_scores = [] 465 | maxlen = 0 466 | new_mems = mems 467 | utterances = [] 468 | utterances_ind = [] 469 | 470 | for line in lines: 471 | linevec = line.strip().split() 472 | ac_score = float(linevec[0]) 473 | utterance = linevec[4:-1] 474 | currentline = [] 475 | for i, word in enumerate(utterance): 476 | if word in dictionary: 477 | currentline.append(int(dictionary[word])) 478 | else: 479 | currentline.append(int(dictionary[''])) 480 | currentline = [eosidx] + currentline 481 | currenttarget = currentline[1:] 482 | currenttarget.append(eosidx) 483 | inputs.append(currentline) 484 | targets.append(currenttarget) 485 | utterances.append(utterance) 486 | utterances_ind.append(currenttarget) 487 | ac_scores.append(ac_score) 488 | if len(currentline) > maxlen: 489 | maxlen = len(currentline) 490 | mask = [] 491 | ac_score_tensor = torch.tensor(ac_scores).to(device) 492 | for i, symbols in enumerate(inputs): 493 | inputs[i] = symbols + [eosidx] * (maxlen - len(symbols)) 494 
| targets[i] = targets[i] + [eosidx] * (maxlen - len(symbols)) 495 | mask.append([1.0] * len(symbols) + [0.0] * (maxlen - len(symbols))) 496 | 497 | input_tensor = torch.LongTensor(inputs).to(device).t().contiguous() 498 | target_tensor = torch.LongTensor(targets).to(device).t().contiguous() 499 | mask_tensor = torch.tensor(mask).to(device).t().contiguous() 500 | bsize = input_tensor.size(1) 501 | seq_len = input_tensor.size(0) 502 | 503 | # Forward prop transformer 504 | logProblist = [] 505 | mem_list = [] 506 | ppl_list = [] 507 | hidden_list = [] 508 | if hidden is None and getattr(model, "rnnenc", False): 509 | hidden = model.init_hidden(1) 510 | # transformer XL 511 | prev_mem_len = 0 512 | if len(mems) > 0: 513 | prev_mem_len = mems[0].size(0) 514 | tiled_mems = [mem.repeat(1, bsize, 1) for mem in mems] 515 | for i in range(0, bsize, args.subbatchsize): 516 | # mems for transformer XL 517 | if len(mems) > 0: 518 | this_mem = [mem[:, i:i+args.subbatchsize, :].contiguous() for mem in tiled_mems] 519 | else: 520 | this_mem = mems 521 | bsz = min(args.subbatchsize, bsize - i) 522 | # expand rnn hidden state 523 | rnn_hidden = None 524 | if hidden is not None: 525 | rnn_hidden = [(hid[0].repeat(1, bsz, 1), hid[1].repeat(1, bsz, 1)) for hid in hidden] 526 | # forward pass 527 | ret = model(input_tensor[:, i:i+args.subbatchsize].contiguous(), 528 | target_tensor[:, i:i+args.subbatchsize].contiguous(), 529 | *this_mem, rnn_hidden=rnn_hidden) 530 | loss, this_mems, rnn_hidden = ret[0], ret[1:-1], ret[-1] 531 | if args.mem_len > 0: 532 | mem_list.append(torch.stack(this_mems)) 533 | if hidden is not None: 534 | hidden_list.append(rnn_hidden) 535 | loss = loss * mask_tensor[:, i:i+args.subbatchsize] 536 | logProblist.append(torch.sum(loss, dim=0)) 537 | # outputlist.append(output[:,-1,:]) 538 | lmscores = torch.cat(logProblist, 0) 539 | # lmscores = torch.sum(logProb.view(seq_len, bsize)*mask_tensor, 0) 540 | total_scores = - lmscores * args.lmscale + ac_score_tensor 541 | # Get output in some format 542 | outputlines = [] 543 | for i, utt in enumerate(utterances): 544 | out = ' '.join([utt_name+'-'+str(i+1), '{:5.2f}'.format(lmscores[i])]) 545 | outputlines.append(out+'\n') 546 | max_ind = torch.argmax(total_scores) 547 | # choose best rnn hidden state 548 | if len(hidden_list) > 0: 549 | all_hidden = [] 550 | for hid in zip(*hidden_list): 551 | hid_l = list(zip(*hid)) 552 | all_hidden.append((torch.cat(hid_l[0], dim=1), torch.cat(hid_l[1], dim=1))) 553 | best_hid = [(hid[0][:, max_ind:max_ind+1, :], hid[1][:, max_ind:max_ind+1, :]) for hid in all_hidden] 554 | else: 555 | best_hid = None 556 | 557 | best_utt = utterances[max_ind] 558 | prev_utts += utterances_ind[max_ind] 559 | if len(mem_list) > 0: 560 | new_mems = torch.cat(mem_list, dim=2) 561 | start_pos = max(new_mems.size(1) - seq_len, 0) 562 | new_mems = new_mems[:, start_pos:start_pos+len(best_utt)+1, max_ind:max_ind+1, :] 563 | if len(mems) > 0: 564 | new_mems = [torch.cat([mems[i], new_mems[i]])[-args.mem_len:] for i in range(new_mems.size(0))] 565 | else: 566 | new_mems = [new_mems[i] for i in range(new_mems.size(0))] 567 | # calculate perplexity 568 | best_ppl = lmscores[max_ind] if ppl else None 569 | return best_utt, outputlines, prev_utts, new_mems, best_ppl, best_hid 570 | 571 | def forward_nbest_utterance(model, nbestfile, extramodel=None): 572 | """ The main body of the rescore function. 
""" 573 | model.eval() 574 | # decide if we calculate the average of the loss 575 | if args.interp: 576 | forwardCrit = torch.nn.CrossEntropyLoss(reduction='none') 577 | else: 578 | forwardCrit = torch.nn.CrossEntropyLoss() 579 | 580 | extrahidden = None 581 | if args.extra_model != '' and extramodel is not None: 582 | extramodel.eval() 583 | extrahidden = extramodel.init_hidden(1) 584 | # initialising variables needed 585 | ngram_cursor = 0 586 | lmscored_lines = [] 587 | best_utt_list = [] 588 | emb_list = [] 589 | utt_idx = 0 590 | prev_utts = [eosidx] # * args.lookback 591 | start = time.time() 592 | best_hid = None 593 | best_ppl = None 594 | mems = tuple() 595 | 596 | total_ppl = torch.zeros(1) 597 | total_len = 0 598 | 599 | # Ngram used for lattice rescoring 600 | with open(nbestfile) as filein: 601 | with torch.no_grad(): 602 | for utterancefile in filein: 603 | # Iterating over the nbest list 604 | labname = utterancefile.strip().split('/')[-1] 605 | # Read in ngram LM files for interpolation 606 | if args.interp: 607 | ngram_probfile_name = ngram_listfile.readline() 608 | ngram_probfile = open(ngram_probfile_name.strip()) 609 | ngram_prob_lines = ngram_probfile.readlines() 610 | future_context = future_context_dict[utt_idx] if args.use_context else [] 611 | 612 | # Start processing each nbestlist 613 | with open(utterancefile.strip()) as uttfile: 614 | uttlines = uttfile.readlines() 615 | 616 | uttscore = [] 617 | # Do re-ranking batch by batch 618 | if not args.interp: 619 | if args.lookback > 0 and getattr(model, "rnnenc", False): 620 | bestutt, to_write, prev_utts, mems, best_ppl, best_hid, extrahidden = forward_each_utt_batched_lookback_rnn( 621 | model, uttlines, labname, prev_utts, mems, args.ppl, best_hid, 622 | extra_model=extramodel, extra_hidden=extrahidden) 623 | elif args.lookback > 0: 624 | bestutt, to_write, prev_utts, mems, best_ppl, extrahidden = forward_each_utt_batched_lookback( 625 | model, uttlines, labname, prev_utts, mems, args.ppl, 626 | extra_model=extramodel, extra_hidden=extrahidden) 627 | else: 628 | bestutt, to_write, prev_utts, mems, best_ppl, best_hid = forward_each_utt_batched( 629 | model, uttlines, labname, prev_utts, mems, args.ppl, best_hid) 630 | lmscored_lines += to_write 631 | utt_idx += 1 632 | best_utt_list.append((labname, bestutt)) 633 | if args.ppl and best_ppl is not None: 634 | total_ppl += torch.sum(best_ppl) 635 | total_len += len(bestutt) + 1 636 | # Log every completion of n utterances 637 | if utt_idx % 100 == 0: 638 | logging("current ppl is {:5.2f}".format(torch.exp(total_ppl / total_len).item())) 639 | logging('rescored {} utterances, time overlapped {:6.2f}'.format(str(utt_idx), time.time()-start)) 640 | # Write out renewed lmscore file 641 | with open(nbestfile+'.renew.'+args.lm, 'w') as fout: 642 | fout.writelines(lmscored_lines) 643 | # Save 1-best for later use for the context 644 | if args.save1best: 645 | # Write out for second level forwarding 646 | with open(nbestfile+'.context', 'w') as fout: 647 | for i, eachutt in enumerate(best_utt_list): 648 | linetowrite = ' ' + ' '.join(eachutt[1]) + ' \n' 649 | fout.write(linetowrite) 650 | if args.ppl: 651 | print(total_len) 652 | print(torch.exp(total_ppl / total_len)) 653 | 654 | logging('getting utterances') 655 | forward_nbest_utterance(model, args.nbest, extramodel) 656 | if args.pplword: 657 | with open("word_ppl_{}".format(args.lm), 'w') as fin: 658 | for word, group in wordppl.items(): 659 | total_ppl, total_count = group 660 | if total_count > 0: 661 | 
fin.write('{}\t\t{}\t\t{:5.3f}\n'.format(word, total_count, float(total_ppl/total_count))) 662 | -------------------------------------------------------------------------------- /mem_transformer_rnn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import functools 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | sys.path.append('utils') 12 | from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax 13 | from log_uniform_sampler import LogUniformSampler, sample_logits 14 | 15 | class PositionalEmbedding(nn.Module): 16 | def __init__(self, demb): 17 | super(PositionalEmbedding, self).__init__() 18 | 19 | self.demb = demb 20 | 21 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 22 | self.register_buffer('inv_freq', inv_freq) 23 | 24 | def forward(self, pos_seq, bsz=None): 25 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 26 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 27 | 28 | if bsz is not None: 29 | return pos_emb[:,None,:].expand(-1, bsz, -1) 30 | else: 31 | return pos_emb[:,None,:] 32 | 33 | 34 | class PositionwiseFF(nn.Module): 35 | def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): 36 | super(PositionwiseFF, self).__init__() 37 | 38 | self.d_model = d_model 39 | self.d_inner = d_inner 40 | self.dropout = dropout 41 | 42 | self.CoreNet = nn.Sequential( 43 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), 44 | nn.Dropout(dropout), 45 | nn.Linear(d_inner, d_model), 46 | nn.Dropout(dropout), 47 | ) 48 | 49 | self.layer_norm = nn.LayerNorm(d_model) 50 | 51 | self.pre_lnorm = pre_lnorm 52 | 53 | def forward(self, inp): 54 | if self.pre_lnorm: 55 | ##### layer normalization + positionwise feed-forward 56 | core_out = self.CoreNet(self.layer_norm(inp)) 57 | 58 | ##### residual connection 59 | output = core_out + inp 60 | else: 61 | ##### positionwise feed-forward 62 | core_out = self.CoreNet(inp) 63 | 64 | ##### residual connection + layer normalization 65 | output = self.layer_norm(inp + core_out) 66 | 67 | return output 68 | 69 | class MultiHeadAttn(nn.Module): 70 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, 71 | pre_lnorm=False, penalty=False): 72 | super(MultiHeadAttn, self).__init__() 73 | 74 | self.n_head = n_head 75 | self.d_model = d_model 76 | self.d_head = d_head 77 | self.dropout = dropout 78 | 79 | self.q_net = nn.Linear(d_model, n_head * d_head, bias=False) 80 | self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False) 81 | 82 | self.drop = nn.Dropout(dropout) 83 | self.dropatt = nn.Dropout(dropatt) 84 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 85 | 86 | self.layer_norm = nn.LayerNorm(d_model) 87 | 88 | self.scale = 1 / (d_head ** 0.5) 89 | 90 | self.pre_lnorm = pre_lnorm 91 | self.penalty = penalty 92 | 93 | def calc_penalty(self, attn_prob): 94 | if self.penalty: 95 | ATA = torch.einsum('ijbm,ijbn->ibmn', attn_prob, attn_prob) 96 | seqlen = attn_prob.size(0) 97 | lambda_array = torch.diag(torch.tensor([1.0]*(self.n_head))) 98 | pen = ((ATA - torch.eye(ATA.size(2), device=ATA.device)) ** 2).sum() 99 | else: 100 | pen = attn_prob.new_zeros(1) 101 | return pen 102 | 103 | def forward(self, h, attn_mask=None, mems=None): 104 | ##### multihead attention 105 | # [hlen x bsz x n_head x d_head] 106 | 107 | if mems is not None: 108 | c = torch.cat([mems, h], 0) 109 | else: 110 | c = h 111 | 112 | if self.pre_lnorm: 113 | ##### layer 
normalization 114 | c = self.layer_norm(c) 115 | 116 | head_q = self.q_net(h) 117 | head_k, head_v = torch.chunk(self.kv_net(c), 2, -1) 118 | 119 | head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head) 120 | head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head) 121 | head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head) 122 | 123 | # [qlen x klen x bsz x n_head] 124 | attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k)) 125 | attn_score.mul_(self.scale) 126 | if attn_mask is not None and attn_mask.any().item(): 127 | if attn_mask.dim() == 2: 128 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) 129 | elif attn_mask.dim() == 3: 130 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) 131 | 132 | # [qlen x klen x bsz x n_head] 133 | attn_prob = F.softmax(attn_score, dim=1) 134 | attn_penalty = self.calc_penalty(attn_prob) 135 | attn_prob = self.dropatt(attn_prob) 136 | 137 | # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head] 138 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) 139 | attn_vec = attn_vec.contiguous().view( 140 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 141 | 142 | ##### linear projection 143 | attn_out = self.o_net(attn_vec) 144 | attn_out = self.drop(attn_out) 145 | 146 | if self.pre_lnorm: 147 | ##### residual connection 148 | output = h + attn_out 149 | else: 150 | ##### residual connection + layer normalization 151 | output = self.layer_norm(h + attn_out) 152 | 153 | return output, attn_penalty 154 | 155 | class RelMultiHeadAttn(nn.Module): 156 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, 157 | tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, penalty=False): 158 | super(RelMultiHeadAttn, self).__init__() 159 | 160 | self.n_head = n_head 161 | self.d_model = d_model 162 | self.d_head = d_head 163 | self.dropout = dropout 164 | 165 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) 166 | 167 | self.drop = nn.Dropout(dropout) 168 | self.dropatt = nn.Dropout(dropatt) 169 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 170 | 171 | self.layer_norm = nn.LayerNorm(d_model) 172 | 173 | self.scale = 1 / (d_head ** 0.5) 174 | 175 | self.pre_lnorm = pre_lnorm 176 | self.penalty = penalty 177 | 178 | def _parallelogram_mask(self, h, w, left=False): 179 | mask = torch.ones((h, w)).byte() 180 | m = min(h, w) 181 | mask[:m,:m] = torch.triu(mask[:m,:m]) 182 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:]) 183 | 184 | if left: 185 | return mask 186 | else: 187 | return mask.flip(0) 188 | 189 | def _shift(self, x, qlen, klen, mask, left=False): 190 | if qlen > 1: 191 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)), 192 | device=x.device, dtype=x.dtype) 193 | else: 194 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) 195 | 196 | if left: 197 | mask = mask.flip(1) 198 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) 199 | else: 200 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) 201 | 202 | x = x_padded.masked_select(mask[:,:,None,None]) \ 203 | .view(qlen, klen, x.size(2), x.size(3)) 204 | 205 | return x 206 | 207 | def _calc_penalty(self, attn_prob): 208 | if self.penalty: 209 | ATA = torch.einsum('ijbm,ijbn->ibmn', attn_prob, attn_prob) 210 | seqlen = attn_prob.size(0) 211 | lambda_array = torch.diag(torch.tensor([0.5]*(self.n_head))) 212 | pen = ((ATA - torch.eye(ATA.size(2), 
device=ATA.device)) ** 2).sum() 213 | else: 214 | pen = attn_prob.new_zeros(1) 215 | return pen 216 | 217 | def _rel_shift(self, x, zero_triu=False): 218 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), 219 | device=x.device, dtype=x.dtype) 220 | # import pdb; pdb.set_trace() 221 | x_padded = torch.cat([zero_pad, x], dim=1) 222 | 223 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) 224 | 225 | x = x_padded[1:].view_as(x) 226 | 227 | if zero_triu: 228 | ones = torch.ones((x.size(0), x.size(1))) 229 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] 230 | 231 | return x 232 | 233 | def forward(self, w, r, attn_mask=None, mems=None): 234 | raise NotImplementedError 235 | 236 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): 237 | def __init__(self, *args, **kwargs): 238 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 239 | 240 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) 241 | 242 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): 243 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) 244 | 245 | if mems is not None: 246 | cat = torch.cat([mems, w], 0) 247 | if self.pre_lnorm: 248 | w_heads = self.qkv_net(self.layer_norm(cat)) 249 | else: 250 | w_heads = self.qkv_net(cat) 251 | r_head_k = self.r_net(r) 252 | 253 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 254 | w_head_q = w_head_q[-qlen:] 255 | else: 256 | if self.pre_lnorm: 257 | w_heads = self.qkv_net(self.layer_norm(w)) 258 | else: 259 | w_heads = self.qkv_net(w) 260 | r_head_k = self.r_net(r) 261 | 262 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 263 | 264 | klen = w_head_k.size(0) 265 | 266 | # import pdb; pdb.set_trace() 267 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head 268 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head 269 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head 270 | 271 | r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head 272 | 273 | #### compute attention score 274 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head 275 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head 276 | 277 | rr_head_q = w_head_q + r_r_bias 278 | BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head 279 | BD = self._rel_shift(BD) 280 | 281 | # [qlen x klen x bsz x n_head] 282 | attn_score = AC + BD 283 | attn_score.mul_(self.scale) 284 | 285 | #### compute attention probability 286 | if attn_mask is not None and attn_mask.any().item(): 287 | if attn_mask.dim() == 2: 288 | attn_score = attn_score.float().masked_fill( 289 | attn_mask[None,:,:,None], -float('inf')).type_as(attn_score) 290 | elif attn_mask.dim() == 3: 291 | attn_score = attn_score.float().masked_fill( 292 | attn_mask[:,:,:,None], -float('inf')).type_as(attn_score) 293 | 294 | # [qlen x klen x bsz x n_head] 295 | attn_prob = F.softmax(attn_score, dim=1) 296 | attn_prob = self.dropatt(attn_prob) 297 | attn_penalty = self._calc_penalty(attn_prob) 298 | 299 | #### compute attention vector 300 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 301 | 302 | # [qlen x bsz x n_head x d_head] 303 | attn_vec = attn_vec.contiguous().view( 304 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 305 | 306 | ##### linear projection 307 | attn_out 
= self.o_net(attn_vec) 308 | attn_out = self.drop(attn_out) 309 | 310 | if self.pre_lnorm: 311 | ##### residual connection 312 | output = w + attn_out 313 | else: 314 | ##### residual connection + layer normalization 315 | output = self.layer_norm(w + attn_out) 316 | 317 | return output, attn_penalty 318 | 319 | class RelLearnableMultiHeadAttn(RelMultiHeadAttn): 320 | def __init__(self, *args, **kwargs): 321 | super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 322 | 323 | def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): 324 | # r_emb: [klen, n_head, d_head], used for term B 325 | # r_w_bias: [n_head, d_head], used for term C 326 | # r_bias: [klen, n_head], used for term D 327 | 328 | qlen, bsz = w.size(0), w.size(1) 329 | 330 | if mems is not None: 331 | cat = torch.cat([mems, w], 0) 332 | if self.pre_lnorm: 333 | w_heads = self.qkv_net(self.layer_norm(cat)) 334 | else: 335 | w_heads = self.qkv_net(cat) 336 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 337 | 338 | w_head_q = w_head_q[-qlen:] 339 | else: 340 | if self.pre_lnorm: 341 | w_heads = self.qkv_net(self.layer_norm(w)) 342 | else: 343 | w_heads = self.qkv_net(w) 344 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 345 | 346 | klen = w_head_k.size(0) 347 | 348 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) 349 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) 350 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) 351 | 352 | if klen > r_emb.size(0): 353 | r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1) 354 | r_emb = torch.cat([r_emb_pad, r_emb], 0) 355 | r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1) 356 | r_bias = torch.cat([r_bias_pad, r_bias], 0) 357 | else: 358 | r_emb = r_emb[-klen:] 359 | r_bias = r_bias[-klen:] 360 | 361 | #### compute attention score 362 | rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head 363 | 364 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head 365 | B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head 366 | D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head 367 | BD = self._rel_shift(B_ + D_) 368 | 369 | # [qlen x klen x bsz x n_head] 370 | attn_score = AC + BD 371 | attn_score.mul_(self.scale) 372 | 373 | #### compute attention probability 374 | if attn_mask is not None and attn_mask.any().item(): 375 | if attn_mask.dim() == 2: 376 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) 377 | elif attn_mask.dim() == 3: 378 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) 379 | 380 | # [qlen x klen x bsz x n_head] 381 | attn_prob = F.softmax(attn_score, dim=1) 382 | attn_prob = self.dropatt(attn_prob) 383 | 384 | #### compute attention vector 385 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 386 | 387 | # [qlen x bsz x n_head x d_head] 388 | attn_vec = attn_vec.contiguous().view( 389 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 390 | 391 | ##### linear projection 392 | attn_out = self.o_net(attn_vec) 393 | attn_out = self.drop(attn_out) 394 | 395 | if self.pre_lnorm: 396 | ##### residual connection 397 | output = w + attn_out 398 | else: 399 | ##### residual connection + layer normalization 400 | output = self.layer_norm(w + attn_out) 401 | 402 | return output 403 | 404 | class DecoderLayer(nn.Module): 405 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): 406 | super(DecoderLayer, 
self).__init__() 407 | 408 | self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) 409 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 410 | pre_lnorm=kwargs.get('pre_lnorm')) 411 | 412 | def forward(self, dec_inp, dec_attn_mask=None, mems=None): 413 | 414 | output, attn_pen = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, 415 | mems=mems) 416 | output = self.pos_ff(output) 417 | 418 | return output, attn_pen 419 | 420 | class RelLearnableDecoderLayer(nn.Module): 421 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, 422 | **kwargs): 423 | super(RelLearnableDecoderLayer, self).__init__() 424 | 425 | self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, 426 | **kwargs) 427 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 428 | pre_lnorm=kwargs.get('pre_lnorm')) 429 | 430 | def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): 431 | 432 | output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, 433 | attn_mask=dec_attn_mask, 434 | mems=mems) 435 | output = self.pos_ff(output) 436 | 437 | return output 438 | 439 | class RelPartialLearnableDecoderLayer(nn.Module): 440 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, 441 | **kwargs): 442 | super(RelPartialLearnableDecoderLayer, self).__init__() 443 | 444 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, 445 | d_head, dropout, **kwargs) 446 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 447 | pre_lnorm=kwargs.get('pre_lnorm')) 448 | 449 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): 450 | 451 | output, attn_p = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, 452 | attn_mask=dec_attn_mask, 453 | mems=mems) 454 | output = self.pos_ff(output) 455 | 456 | return output, attn_p 457 | 458 | 459 | class AdaptiveEmbedding(nn.Module): 460 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 461 | sample_softmax=False): 462 | super(AdaptiveEmbedding, self).__init__() 463 | 464 | self.n_token = n_token 465 | self.d_embed = d_embed 466 | 467 | self.cutoffs = cutoffs + [n_token] 468 | self.div_val = div_val 469 | self.d_proj = d_proj 470 | 471 | self.emb_scale = d_proj ** 0.5 472 | 473 | self.cutoff_ends = [0] + self.cutoffs 474 | 475 | self.emb_layers = nn.ModuleList() 476 | self.emb_projs = nn.ParameterList() 477 | if div_val == 1: 478 | self.emb_layers.append( 479 | nn.Embedding(n_token, d_embed, sparse=sample_softmax>0) 480 | ) 481 | if d_proj != d_embed: 482 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) 483 | else: 484 | for i in range(len(self.cutoffs)): 485 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 486 | d_emb_i = d_embed // (div_val ** i) 487 | self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) 488 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i))) 489 | 490 | def forward(self, inp): 491 | if self.div_val == 1: 492 | embed = self.emb_layers[0](inp) 493 | if self.d_proj != self.d_embed: 494 | embed = F.linear(embed, self.emb_projs[0]) 495 | else: 496 | param = next(self.parameters()) 497 | inp_flat = inp.view(-1) 498 | emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], 499 | dtype=param.dtype, device=param.device) 500 | for i in range(len(self.cutoffs)): 501 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 502 | 503 | mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) 504 | indices_i = mask_i.nonzero().squeeze() 505 | 506 | if indices_i.numel() == 0: 507 | continue 508 | 509 
| inp_i = inp_flat.index_select(0, indices_i) - l_idx 510 | emb_i = self.emb_layers[i](inp_i) 511 | emb_i = F.linear(emb_i, self.emb_projs[i]) 512 | 513 | emb_flat.index_copy_(0, indices_i, emb_i) 514 | 515 | embed = emb_flat.view(*inp.size(), self.d_proj) 516 | 517 | embed.mul_(self.emb_scale) 518 | 519 | return embed 520 | 521 | 522 | class MemTransformerLM(nn.Module): 523 | def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner, 524 | dropout, dropatt, tie_weight=True, d_embed=None, 525 | div_val=1, tie_projs=[False], pre_lnorm=False, 526 | tgt_len=None, ext_len=None, mem_len=None, 527 | cutoffs=[], adapt_inp=False, 528 | same_length=False, attn_type=0, clamp_len=-1, 529 | sample_softmax=-1, rnnenc=False, rnndim=0, 530 | layer_list='', future_len=0, attn_layerlist='', merge_type='direct'): 531 | super(MemTransformerLM, self).__init__() 532 | self.n_token = n_token 533 | 534 | d_embed = d_model if d_embed is None else d_embed 535 | self.d_embed = d_embed 536 | self.d_model = d_model 537 | self.n_head = n_head 538 | self.d_head = d_head 539 | 540 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, 541 | div_val=div_val) 542 | 543 | self.drop = nn.Dropout(dropout) 544 | 545 | self.n_layer = n_layer 546 | 547 | self.tgt_len = tgt_len 548 | self.mem_len = mem_len 549 | self.ext_len = ext_len 550 | self.future_len = future_len 551 | self.max_klen = tgt_len + ext_len + mem_len + future_len 552 | 553 | self.attn_type = attn_type 554 | 555 | # RNN hidden state carry on 556 | self.layer_list = [int(i) for i in layer_list.split()] 557 | self.rnnlayer_list = self.layer_list 558 | print("rnn layer list: {}".format(self.rnnlayer_list)) 559 | if rnnenc and rnndim != 0: 560 | if merge_type in ['gating', 'project']: 561 | self.rnnproj = nn.Linear(rnndim + d_model, d_model) 562 | self.rnn_list = nn.ModuleList([nn.LSTM(d_model, rnndim, 1) for i in range(len(self.rnnlayer_list))]) 563 | # attn penalisation 564 | self.attn_pen_layers = [int(i) for i in attn_layerlist.split()] 565 | self.merge_type = merge_type 566 | 567 | self.layers = nn.ModuleList() 568 | if attn_type == 0: # the default attention 569 | for i in range(n_layer): 570 | # dropatt = dropatt * 2 if (i == 0 and rnnenc) else dropatt 571 | use_penalty = i in self.attn_pen_layers 572 | self.layers.append( 573 | RelPartialLearnableDecoderLayer( 574 | n_head, d_model, d_head, d_inner, dropout, 575 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 576 | dropatt=dropatt, pre_lnorm=pre_lnorm, penalty=use_penalty) 577 | ) 578 | elif attn_type == 1: # learnable embeddings 579 | for i in range(n_layer): 580 | self.layers.append( 581 | RelLearnableDecoderLayer( 582 | n_head, d_model, d_head, d_inner, dropout, 583 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 584 | dropatt=dropatt, pre_lnorm=pre_lnorm) 585 | ) 586 | elif attn_type in [2, 3]: # absolute embeddings 587 | for i in range(n_layer): 588 | use_penalty = i in self.attn_pen_layers 589 | # dropatt = dropatt * 0 if (i == 0 and rnnenc) else dropatt 590 | self.layers.append( 591 | DecoderLayer( 592 | n_head, d_model, d_head, d_inner, dropout, 593 | dropatt=dropatt, pre_lnorm=pre_lnorm, penalty=use_penalty) 594 | ) 595 | 596 | self.sample_softmax = sample_softmax 597 | # use sampled softmax 598 | if sample_softmax > 0: 599 | self.out_layer = nn.Linear(d_model, n_token) 600 | if tie_weight: 601 | self.out_layer.weight = self.word_emb.weight 602 | self.tie_weight = tie_weight 603 | self.sampler = LogUniformSampler(n_token, sample_softmax) 604 | 605 | # use 
adaptive softmax (including standard softmax) 606 | else: 607 | self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, 608 | cutoffs, div_val=div_val) 609 | 610 | if tie_weight: 611 | for i in range(len(self.crit.out_layers)): 612 | self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight 613 | 614 | if tie_projs: 615 | for i, tie_proj in enumerate(tie_projs): 616 | if tie_proj and div_val == 1 and d_model != d_embed: 617 | self.crit.out_projs[i] = self.word_emb.emb_projs[0] 618 | elif tie_proj and div_val != 1: 619 | self.crit.out_projs[i] = self.word_emb.emb_projs[i] 620 | 621 | self.rnnenc = rnnenc 622 | self.rnndim = rnndim 623 | self.same_length = same_length 624 | self.clamp_len = clamp_len 625 | 626 | self._create_params() 627 | 628 | def backward_compatible(self): 629 | self.sample_softmax = -1 630 | 631 | def _create_params(self): 632 | if self.attn_type == 0: # default attention 633 | self.pos_emb = PositionalEmbedding(self.d_model) 634 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 635 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 636 | elif self.attn_type == 1: # learnable 637 | self.r_emb = nn.Parameter(torch.Tensor( 638 | self.n_layer, self.max_klen, self.n_head, self.d_head)) 639 | self.r_w_bias = nn.Parameter(torch.Tensor( 640 | self.n_layer, self.n_head, self.d_head)) 641 | self.r_bias = nn.Parameter(torch.Tensor( 642 | self.n_layer, self.max_klen, self.n_head)) 643 | elif self.attn_type == 2: # absolute standard 644 | self.pos_emb = PositionalEmbedding(self.d_model) 645 | elif self.attn_type == 3: # absolute deeper SA 646 | self.r_emb = nn.Parameter(torch.Tensor( 647 | self.n_layer, self.max_klen, self.n_head, self.d_head)) 648 | # for RNN projection 649 | if self.rnnenc and self.rnndim != self.d_model: 650 | self.rnnproj.bias.data.zero_() 651 | self.rnnproj.weight.data.uniform_(-0.1, 0.1) 652 | 653 | def reset_length(self, tgt_len, ext_len, mem_len): 654 | self.tgt_len = tgt_len 655 | self.mem_len = mem_len 656 | self.ext_len = ext_len 657 | 658 | def init_mems(self): 659 | if self.mem_len > 0: 660 | mems = [] 661 | param = next(self.parameters()) 662 | for i in range(self.n_layer+1): 663 | empty = torch.empty(0, dtype=param.dtype, device=param.device) 664 | mems.append(empty) 665 | 666 | return mems 667 | else: 668 | return None 669 | 670 | def _update_mems(self, hids, mems, qlen, mlen): 671 | # does not deal with None 672 | if mems is None: return None 673 | 674 | # mems is not None 675 | assert len(hids) == len(mems), 'len(hids) != len(mems)' 676 | 677 | # There are `mlen + qlen` steps that can be cached into mems 678 | # For the next step, the last `ext_len` of the `qlen` tokens 679 | # will be used as the extended context. Hence, we only cache 680 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len` 681 | # to `mlen + qlen - self.ext_len`. 
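# Worked example of the caching window below (illustrative values only): with
# mlen = qlen = 32, ext_len = 0 and mem_len = 32, end_idx = 32 + max(0, 32 - 0) = 64
# and beg_idx = max(0, 64 - 32) = 32, so the slice [32:64) of the concatenated
# [mems; hids] sequence is kept as the new memory.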
682 | with torch.no_grad(): 683 | new_mems = [] 684 | end_idx = mlen + max(0, qlen - 0 - self.ext_len) 685 | beg_idx = max(0, end_idx - self.mem_len) 686 | for i in range(len(hids)): 687 | 688 | cat = torch.cat([mems[i], hids[i]], dim=0) 689 | new_mems.append(cat[beg_idx:end_idx].detach()) 690 | 691 | return new_mems 692 | 693 | def init_hidden(self, bsz): 694 | hidden_list = [] 695 | for i in self.rnnlayer_list: 696 | weight = next(self.parameters()) 697 | hidden_list.append((weight.new_zeros(1, bsz, getattr(self, 'rnndim', self.d_model)), 698 | weight.new_zeros(1, bsz, getattr(self, 'rnndim', self.d_model)))) 699 | return hidden_list 700 | 701 | def init_hidden_singlelayer(self, bsz): 702 | weight = next(self.parameters()) 703 | hidden = (weight.new_zeros(1, bsz, self.rnndim), weight.new_zeros(1, bsz, self.rnndim)) 704 | return hidden 705 | 706 | def forward_rnn(self, i, core_out, rnn_hidden, stepwise=False, future_seqlen=0, tgtlen=0): 707 | # gs534 - rnn in the middle of a transformer layer 708 | index = self.rnnlayer_list.index(i) 709 | if not stepwise: 710 | if self.ext_len > 0: 711 | rnn_out_ext = None 712 | if tgtlen != core_out.size(0) - future_seqlen: 713 | rnn_out_ext, rnn_hidden[index] = self.rnn_list[index]( 714 | core_out[:core_out.size(0)-future_seqlen-tgtlen], rnn_hidden[index]) 715 | rnn_out, _ = self.rnn_list[index]( 716 | core_out[core_out.size(0)-future_seqlen-tgtlen:core_out.size(0)-future_seqlen], rnn_hidden[index]) 717 | if rnn_out_ext is not None: 718 | rnn_out = torch.cat([rnn_out_ext, rnn_out], dim=0) 719 | else: 720 | rnn_out, rnn_hidden[index] = self.rnn_list[index]( 721 | core_out[:core_out.size(0)-future_seqlen], rnn_hidden[index]) 722 | if future_seqlen > 1: 723 | new_hiddens = self.init_hidden_singlelayer(core_out.size(1)) 724 | rnnout_future, _ = self.rnn_list[index](core_out[-future_seqlen+1:], new_hiddens) 725 | rnn_out = torch.cat([rnn_out, core_out[-future_seqlen:-future_seqlen+1], rnnout_future], dim=0) 726 | else: 727 | step_hidden = rnn_hidden[index] 728 | rnn_hidden_step = [] 729 | core_out_list = [] 730 | for k in range(core_out.size(0) - future_seqlen): 731 | rnn_output, step_hidden = self.rnn_list[index](core_out[k:k+1], step_hidden) 732 | core_out_list.append(rnn_output) 733 | rnn_hidden_step.append(step_hidden) 734 | rnn_hidden_step = list(zip(*rnn_hidden_step)) 735 | rnn_hidden[index] = (torch.cat(rnn_hidden_step[0], dim=0), torch.cat(rnn_hidden_step[1], dim=0)) 736 | rnn_out = torch.cat(core_out_list, dim=0) 737 | if future_seqlen > 0: 738 | new_hiddens = self.init_hidden_singlelayer(core_out.size(1)) 739 | rnnout_future, _ = self.rnn_list[index](core_out[-future_seqlen+1:], new_hiddens) 740 | # rnnout_future = core_out[-future_seqlen:] 741 | rnn_out = torch.cat([rnn_out, core_out[-future_seqlen:-future_seqlen+1], rnnout_future], dim=0) 742 | assert (self.rnndim == self.d_model or self.merge_type == 'project') 743 | return rnn_out, rnn_hidden 744 | 745 | def get_future_mask(self, qlen, flen, word_emb, mlen=0): 746 | klen = qlen + mlen 747 | dec_attn_mask = torch.triu( 748 | word_emb.new_ones(qlen-flen, klen-flen), diagonal=1+mlen) 749 | # various paddings 750 | ones_bottom_left = word_emb.new_ones(flen, klen-flen) 751 | ones_single_column = word_emb.new_ones(qlen, 1 if flen > 0 else 0) 752 | ones_top_right = word_emb.new_ones(qlen-flen, flen-1 if flen > 0 else 0) 753 | zeros_bottom_right = word_emb.new_zeros(flen, flen-1 if flen > 0 else 0) 754 | zeros_top_right = word_emb.new_zeros(qlen-flen, flen-1 if flen > 0 else 0) 755 | # construct half 
masks 756 | attn_mask_right = torch.cat([ones_top_right, zeros_bottom_right], dim=0) 757 | attn_mask_right_future = torch.cat([zeros_top_right, zeros_bottom_right], dim=0) 758 | attn_mask_left = torch.cat([dec_attn_mask, ones_bottom_left], dim=0) 759 | # construct attn masks 760 | dec_attn_mask = torch.cat([attn_mask_left, ones_single_column, attn_mask_right], dim=-1) 761 | dec_attn_mask_future = torch.cat( 762 | [attn_mask_left, ones_single_column, attn_mask_right_future], dim=-1) 763 | return dec_attn_mask[:,:,None], dec_attn_mask_future[:,:,None] 764 | 765 | def _forward(self, dec_inp, mems=None, rnn_hidden=None, stepwise=False, future_seqlen=0, tgt_len=0): 766 | qlen, bsz = dec_inp.size() 767 | attn_pen_list = [] 768 | 769 | word_emb = self.word_emb(dec_inp) 770 | 771 | mlen = mems[0].size(0) if mems is not None else 0 772 | klen = mlen + qlen 773 | if self.same_length: 774 | all_ones = word_emb.new_ones(qlen, klen) 775 | mask_len = klen - self.mem_len 776 | if mask_len > 0: 777 | mask_shift_len = qlen - mask_len 778 | else: 779 | mask_shift_len = qlen 780 | dec_attn_mask = (torch.triu(all_ones, 1+mlen) 781 | + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 782 | elif self.future_len == 0: 783 | dec_attn_mask = torch.triu( 784 | word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None] 785 | dec_attn_mask = dec_attn_mask.byte() 786 | else: 787 | dec_attn_mask_normal, dec_attn_mask_future = self.get_future_mask( 788 | qlen, future_seqlen, word_emb, mlen) 789 | dec_attn_mask_normal = dec_attn_mask_normal.byte() 790 | dec_attn_mask_future = dec_attn_mask_future.byte() 791 | 792 | hids = [] 793 | if self.attn_type == 0: # default 794 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, 795 | dtype=word_emb.dtype) 796 | if self.clamp_len > 0: 797 | pos_seq.clamp_(max=self.clamp_len) 798 | pos_emb = self.pos_emb(pos_seq) 799 | 800 | core_out = self.drop(word_emb) 801 | pos_emb = self.drop(pos_emb) 802 | if self.rnnenc and 0 in self.rnnlayer_list: 803 | rnn_out, rnn_hidden = self.forward_rnn(0, core_out, rnn_hidden, stepwise, future_seqlen) 804 | if self.merge_type == 'project': 805 | core_out = torch.relu(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1))) 806 | # core_out = (self.rnnproj(torch.cat([rnn_out, core_out], dim=-1))) 807 | elif self.merge_type == 'gating': 808 | core_gating = torch.sigmoid(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1))) 809 | core_out = rnn_out * core_gating + core_out * (1 - core_gating) 810 | else: 811 | core_out = rnn_out 812 | hids.append(core_out) 813 | 814 | for i, layer in enumerate(self.layers): 815 | if self.future_len != 0: 816 | dec_attn_mask = dec_attn_mask_future if i in self.layer_list else dec_attn_mask_normal 817 | mems_i = None if mems is None else mems[i] 818 | core_out, attn_pen = layer(core_out, pos_emb, self.r_w_bias, 819 | self.r_r_bias, dec_attn_mask=dec_attn_mask, 820 | mems=mems_i) 821 | # gs534 - rnn in the middle of a transformer layer 822 | if self.rnnenc and i+1 in self.rnnlayer_list: 823 | rnn_out, rnn_hidden = self.forward_rnn(i+1, core_out, rnn_hidden, stepwise, future_seqlen) 824 | if self.merge_type == 'project': 825 | core_out = torch.relu(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1))) 826 | # core_out = (self.rnnproj(torch.cat([rnn_out, core_out], dim=-1))) 827 | elif self.merge_type == 'gating': 828 | core_gating = torch.sigmoid(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1))) 829 | core_out = rnn_out * core_gating + core_out * (1 - core_gating) 830 | else: 831 | core_out = 
rnn_out 832 | attn_pen_list.append(attn_pen) 833 | hids.append(core_out) 834 | elif self.attn_type == 1: # learnable 835 | core_out = self.drop(word_emb) 836 | hids.append(core_out) 837 | for i, layer in enumerate(self.layers): 838 | if self.clamp_len > 0: 839 | r_emb = self.r_emb[i][-self.clamp_len :] 840 | r_bias = self.r_bias[i][-self.clamp_len :] 841 | else: 842 | r_emb, r_bias = self.r_emb[i], self.r_bias[i] 843 | 844 | mems_i = None if mems is None else mems[i] 845 | core_out = layer(core_out, r_emb, self.r_w_bias[i], 846 | r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) 847 | hids.append(core_out) 848 | elif self.attn_type == 2: # absolute 849 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, 850 | dtype=word_emb.dtype) 851 | if self.clamp_len > 0: 852 | pos_seq.clamp_(max=self.clamp_len) 853 | pos_emb = self.pos_emb(pos_seq) 854 | 855 | word_emb = self.drop(word_emb) 856 | # rnn_out, rnn_hidden = self.forward_rnn(word_emb, rnn_hidden, stepwise, future_seqlen, tgt_len) 857 | core_out = word_emb 858 | 859 | hids.append(core_out) 860 | for i, layer in enumerate(self.layers): 861 | if self.future_len != 0: 862 | dec_attn_mask = dec_attn_mask_future if i in self.layer_list else dec_attn_mask_normal 863 | mems_i = None if mems is None else mems[i] 864 | if mems_i is not None and mlen > 0: 865 | mems_i += pos_emb[:mlen] 866 | # gs534 - rnn in the middle of a transformer layer 867 | if self.rnnenc and i in self.rnnlayer_list: 868 | core_out = self.drop(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1))) 869 | core_out, attn_pen = layer(core_out, dec_attn_mask=dec_attn_mask, 870 | mems=mems_i) 871 | attn_pen_list.append(attn_pen) 872 | hids.append(core_out) 873 | elif self.attn_type == 3: 874 | core_out = self.drop(word_emb) 875 | 876 | hids.append(core_out) 877 | for i, layer in enumerate(self.layers): 878 | mems_i = None if mems is None else mems[i] 879 | if mems_i is not None and mlen > 0: 880 | cur_emb = self.r_emb[i][:-qlen] 881 | cur_size = cur_emb.size(0) 882 | if cur_size < mlen: 883 | cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1) 884 | cur_emb = torch.cat([cur_emb_pad, cur_emb], 0) 885 | else: 886 | cur_emb = cur_emb[-mlen:] 887 | mems_i += cur_emb.view(mlen, 1, -1) 888 | core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) 889 | 890 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask, 891 | mems=mems_i) 892 | hids.append(core_out) 893 | 894 | core_out = self.drop(core_out) 895 | new_mems = self._update_mems(hids, mems, mlen, qlen) 896 | attn_pen = sum(attn_pen_list)[0] if attn_pen_list != [] else core_out.new_zeros(1) 897 | 898 | return core_out, new_mems, rnn_hidden, attn_pen 899 | 900 | def forward(self, data, target, *mems, rnn_hidden=None, stepwise=False, future_seqlen=0): 901 | # nn.DataParallel does not allow size(0) tensors to be broadcasted. 902 | # So, have to initialize size(0) mems inside the model forward. 903 | # Moreover, have to return new_mems to allow nn.DataParallel to piece 904 | # them together. 
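# The list returned by forward() is [loss] + new_mems + [attn_pen] + [rnn_hidden] when
# memory is enabled (mem_len > 0), and [loss, attn_pen, rnn_hidden] when new_mems is None.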
905 | if not mems: mems = self.init_mems() 906 | 907 | tgt_len = target.size(0) 908 | hidden, new_mems, rnn_hidden, attn_pen = self._forward(data, 909 | mems=mems, 910 | rnn_hidden=rnn_hidden, 911 | stepwise=stepwise, 912 | future_seqlen=future_seqlen, 913 | tgt_len=tgt_len) 914 | 915 | if future_seqlen > 0: 916 | pred_hid = hidden[-tgt_len-future_seqlen:-future_seqlen] 917 | else: 918 | pred_hid = hidden[-tgt_len:] 919 | if self.sample_softmax > 0 and self.training: 920 | assert self.tie_weight 921 | logit = sample_logits(self.word_emb, 922 | self.out_layer.bias, target, pred_hid, self.sampler) 923 | loss = -F.log_softmax(logit, -1)[:, :, 0] 924 | else: 925 | loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) 926 | loss = loss.view(tgt_len, -1) 927 | 928 | if new_mems is None: 929 | return [loss] + [attn_pen] + [rnn_hidden] 930 | else: 931 | return [loss] + new_mems + [attn_pen] + [rnn_hidden] 932 | 933 | if __name__ == '__main__': 934 | import argparse 935 | 936 | parser = argparse.ArgumentParser(description='unit test') 937 | 938 | parser.add_argument('--n_layer', type=int, default=4, help='') 939 | parser.add_argument('--n_rel_layer', type=int, default=4, help='') 940 | parser.add_argument('--n_head', type=int, default=2, help='') 941 | parser.add_argument('--d_head', type=int, default=2, help='') 942 | parser.add_argument('--d_model', type=int, default=200, help='') 943 | parser.add_argument('--d_embed', type=int, default=200, help='') 944 | parser.add_argument('--d_inner', type=int, default=200, help='') 945 | parser.add_argument('--dropout', type=float, default=0.0, help='') 946 | parser.add_argument('--cuda', action='store_true', help='') 947 | parser.add_argument('--seed', type=int, default=1111, help='') 948 | parser.add_argument('--multi_gpu', action='store_true', help='') 949 | 950 | args = parser.parse_args() 951 | 952 | device = torch.device("cuda" if args.cuda else "cpu") 953 | 954 | B = 4 955 | tgt_len, mem_len, ext_len = 36, 36, 0 956 | data_len = tgt_len * 20 957 | args.n_token = 10000 958 | 959 | import data_utils 960 | 961 | data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device) 962 | diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len) 963 | 964 | cutoffs = [args.n_token // 2] 965 | tie_projs = [False] + [True] * len(cutoffs) 966 | 967 | for div_val in [1, 2]: 968 | for d_embed in [200, 100]: 969 | model = MemTransformerLM(args.n_token, args.n_layer, args.n_head, 970 | args.d_model, args.d_head, args.d_inner, args.dropout, 971 | dropatt=args.dropout, tie_weight=True, 972 | d_embed=d_embed, div_val=div_val, 973 | tie_projs=tie_projs, pre_lnorm=True, 974 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 975 | cutoffs=cutoffs, attn_type=0).to(device) 976 | 977 | print(sum(p.numel() for p in model.parameters())) 978 | 979 | mems = tuple() 980 | for idx, (inp, tgt, seqlen) in enumerate(diter): 981 | print('batch {}'.format(idx)) 982 | out = model(inp, tgt, *mems) 983 | mems = out[1:] 984 | --------------------------------------------------------------------------------
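For reference, a minimal usage sketch of the MemTransformerLM defined in mem_transformer_rnn.py above, with an LSTM inserted after the embedding layer (layer_list='0') and direct merging, mirroring the unit test at the bottom of the file. All hyperparameter values here (vocabulary size, model dimensions, sequence and memory lengths) are illustrative and are not the settings used in the repository's training or rescoring scripts; weights are created uninitialised, so the loss is meaningless until an initialisation scheme is applied as the training script would normally do. The sketch assumes the PyTorch 1.0-era behaviour the repository targets, where the initial size-0 memory tensors can be concatenated with the 3-D hidden states.

import torch
from mem_transformer_rnn import MemTransformerLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Illustrative settings only (not the values from run_AMI.sh / run_SWBD.sh).
n_token, bsz = 10000, 4
tgt_len, mem_len, ext_len = 32, 32, 0

model = MemTransformerLM(
    n_token, n_layer=8, n_head=8, d_model=512, d_head=64, d_inner=2048,
    dropout=0.1, dropatt=0.0, tie_weight=True, d_embed=512,
    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, attn_type=0,
    rnnenc=True, rnndim=512, layer_list='0', attn_layerlist='',
    merge_type='direct').to(device)
# NOTE: parameters such as r_w_bias / r_r_bias are allocated without initialisation;
# apply an init scheme before real use.

data = torch.randint(0, n_token, (tgt_len, bsz), device=device)    # shape [qlen, bsz]
target = torch.randint(0, n_token, (tgt_len, bsz), device=device)
rnn_hidden = model.init_hidden(bsz)   # one (h, c) pair per entry in layer_list

mems = tuple()                        # empty: forward() calls init_mems() itself
out = model(data, target, *mems, rnn_hidden=rnn_hidden)
loss, mems, attn_pen, rnn_hidden = out[0], out[1:-2], out[-2], out[-1]
print(loss.shape)                     # [tgt_len, bsz]; pass *mems and rnn_hidden to the next call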