├── utils
│   ├── exp_utils.py
│   ├── adaptive_softmax.py
│   ├── data_parallel.py
│   ├── log_uniform_sampler.py
│   ├── proj_adaptive_softmax.py
│   └── vocabulary.py
├── AMInbest.sh
├── RTnbest.sh
├── SWBDnbest.sh
├── README.md
├── run_AMI.sh
├── run_SWBD.sh
├── data_utils.py
├── train_rnn.py
├── rescore.py
└── mem_transformer_rnn.py
/AMInbest.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE}
2 |
3 | exp_no=1
4 | dataset=AMI
5 | headnum=8
6 | layernum=8
7 | layer=l0
8 | tag=_32_rnn_direct_${layer}
9 |
10 | expdir=${dataset}_transformer/AMI_${layernum}_${headnum}${tag}
11 |
12 | model=${expdir}/model.pt
13 | echo ${model}
14 |
15 | python rescore.py \
16 | --data data/AMI \
17 | --nbest rescore/time_sorted_eval.100bestlist \
18 | --model ${model} \
19 | --lm transformer_rnn_xl_${layer} \
20 | --lmscale 10 \
21 | --lookback 32 \
22 | --subbatchsize 100 \
23 | --cuda \
24 | --mem_len 32 \
25 | --logfile LOGs/nbestlog.txt \
26 | --ppl \
27 |
--------------------------------------------------------------------------------
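rescore.py is listed in the tree above but its source is not reproduced in this section, so the snippet below is only a minimal sketch of the score combination that a flag such as --lmscale 10 conventionally implies in n-best rescoring: each hypothesis keeps its original decoder score, adds the LM log-probability scaled by lmscale, and the highest combined score wins. The data layout and function name are illustrative assumptions, not the actual rescore.py interface.

    # Hedged sketch only: the hypothesis fields and the exact combination used
    # by rescore.py are assumptions, not its real API.
    def rescore_nbest(hypotheses, lm_logprobs, lmscale=10.0):
        """hypotheses: list of (text, decoder_score); lm_logprobs: LM log p(text)."""
        rescored = [
            (text, score + lmscale * lm_lp)
            for (text, score), lm_lp in zip(hypotheses, lm_logprobs)
        ]
        return max(rescored, key=lambda pair: pair[1])  # best combined hypothesis

    best_text, best_score = rescore_nbest(
        [("hello there", -120.0), ("hello their", -119.5)],
        lm_logprobs=[-8.2, -11.7],
    )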
/RTnbest.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE}
2 |
3 | exp_no=1
4 | dataset=SWBD
5 | headnum=8
6 | layernum=24
7 | layer=l0
8 | tag=_64_direct_${layer}
9 |
10 | expdir=${dataset}_transformer/${dataset}_${layernum}_${headnum}${tag}
11 | lm=transformer_rnn_xl_${layer}
12 |
13 | model=${expdir}/model.pt
14 |
15 | python rescore.py \
16 | --data data/SWBD \
17 | --nbest rescore_rt/time_sorted_rt03.nbestlist \
18 | --model ${model} \
19 | --lm ${lm} \
20 | --lmscale 10 \
21 | --lookback 64 \
22 | --subbatchsize 100 \
23 | --cuda \
24 | --mem_len 64 \
25 | --logfile LOGs/nbestlog.txt \
26 | --ppl \
27 |
--------------------------------------------------------------------------------
/SWBDnbest.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE}
2 |
3 | exp_no=1
4 | dataset=SWBD
5 | headnum=8
6 | layernum=24
7 | layer=l0
8 | tag=_64_direct_${layer}
9 |
10 | expdir=${dataset}_transformer/${dataset}_${layernum}_${headnum}${tag}
11 | lm=transformer_rnn_xl_${layer}
12 | echo ${lm}
13 | export PYTHONPATH="${expdir}/scripts:$PYTHONPATH"
14 | echo $PYTHONPATH
15 |
16 | model=${expdir}/model.pt
17 | echo ${model}
18 |
19 | python rescore.py \
20 | --data data/SWBD \
21 | --nbest rescore_swbd/time_sorted_eval2000.nbestlist \
22 | --model ${model} \
23 | --lm ${lm} \
24 | --lmscale 10 \
25 | --lookback 64 \
26 | --subbatchsize 100 \
27 | --cuda \
28 | --mem_len 64 \
29 | --logfile LOGs/nbestlog.txt \
30 | --ppl \
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RTLM
2 | ## Introduction
3 |
4 | This repository contains the code for the paper "Transformer Language Models with LSTM-based Cross-utterance Information Representation". The code is mainly adapted from the [Transformer-XL PyTorch implementation](https://github.com/kimiyoung/transformer-xl.git). Only the single-GPU version is implemented.
5 |
6 | ## Prerequisite
7 | PyTorch 1.0.0
8 |
9 | ## Training
10 | To train on AMI text data
11 | `bash run_AMI.sh train --work_dir PATH_TO_WORK_DIR`
12 |
13 | To train on SWBD text data
14 | `bash run_SWBD.sh train --work_dir PATH_TO_WORK_DIR`
15 |
16 | ## N-best Rescoring
17 | To rescore the AMI n-best list
18 | `bash AMInbest.sh`
19 |
20 | To rescore the SWBD (eval2000) n-best list
21 | `bash SWBDnbest.sh`
22 |
23 | To rescore the RT03 n-best list
24 | `bash RTnbest.sh`
25 |
26 | Note that the paths to the trained LM (--model) and to the n-best list (--nbest) should be changed to your own paths.
27 |
--------------------------------------------------------------------------------
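One detail worth noting before the training scripts below: run_AMI.sh and run_SWBD.sh forward everything after their first argument to train_rnn.py through ${@:2}, and argparse keeps the last value of a repeated flag, so appending options to the README commands above (for example adding --max_step 20000 after --work_dir PATH_TO_WORK_DIR) overrides the corresponding defaults hard-coded inside the script.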
/run_AMI.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE}
3 |
4 | echo 'Run training...'
5 | python train_rnn.py \
6 | --cuda \
7 | --data data/AMI/ \
8 | --dataset AMI \
9 | --work_dir AMI_transformer \
10 | --n_layer 8 \
11 | --d_model 512 \
12 | --div_val 1 \
13 | --n_head 8 \
14 | --d_head 64 \
15 | --d_inner 512 \
16 | --dropout 0.3 \
17 | --dropatt 0.3 \
18 | --optim adam \
19 | --warmup_step 5000 \
20 | --max_step 15000 \
21 | --lr 0.0002 \
22 | --batch_size 24 \
23 | --tgt_len 32 \
24 | --mem_len 32 \
25 | --ext_len 0 \
26 | --future_len 0 \
27 | --eval_tgt_len 32 \
28 | --eval-interval 1600 \
29 | --attn_type 0 \
30 | --scheduler cosine \
31 | --rnnenc \
32 | --rnndim 512 \
33 | --layerlist '0' \
34 | --pen_layerlist '0' \
35 | --merge_type direct \
36 | --p_scale 0.0000 \
37 | ${@:2}
38 |
--------------------------------------------------------------------------------
/run_SWBD.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=${X_SGE_CUDA_DEVICE}
3 |
4 | echo 'Run training...'
5 | python train_rnn.py \
6 | --cuda \
7 | --data data/SWBD/ \
8 | --work_dir SWBD_transformer \
9 | --dataset AMI \
10 | --n_layer 24 \
11 | --d_model 512 \
12 | --div_val 4 \
13 | --n_head 8 \
14 | --d_head 64 \
15 | --d_inner 1024 \
16 | --dropout 0.2 \
17 | --dropatt 0.1 \
18 | --optim adam \
19 | --warmup_step 2000 \
20 | --max_step 200000 \
21 | --lr 0.00025 \
22 | --batch_size 32 \
23 | --tgt_len 64 \
24 | --mem_len 64 \
25 | --ext_len 0 \
26 | --future_len 0 \
27 | --eval_tgt_len 64 \
28 | --eval-interval 1600 \
29 | --attn_type 0 \
30 | --scheduler cosine \
31 | --pre_lnorm \
32 | --rnnenc \
33 | --rnndim 512 \
34 | --layerlist '0' \
35 | --merge_type project \
36 | --log-interval 200 \
37 | ${@:2}
38 |
--------------------------------------------------------------------------------
/utils/exp_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os, shutil
3 |
4 | import numpy as np
5 |
6 | import torch
7 |
8 |
9 | def logging(s, log_path, print_=True, log_=True):
10 | if print_:
11 | print(s)
12 | if log_:
13 | with open(log_path, 'a+') as f_log:
14 | f_log.write(s + '\n')
15 |
16 | def get_logger(log_path, **kwargs):
17 | return functools.partial(logging, log_path=log_path, **kwargs)
18 |
19 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
20 | if debug:
21 | print('Debug Mode : no experiment dir created')
22 | return functools.partial(logging, log_path=None, log_=False)
23 |
24 | if not os.path.exists(dir_path):
25 | os.makedirs(dir_path)
26 |
27 | print('Experiment dir : {}'.format(dir_path))
28 | if scripts_to_save is not None:
29 | script_path = os.path.join(dir_path, 'scripts')
30 | if not os.path.exists(script_path):
31 | os.makedirs(script_path)
32 | for script in scripts_to_save:
33 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
34 | shutil.copyfile(script, dst_file)
35 |
36 | return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
37 |
38 | def save_checkpoint(model, optimizer, path, epoch):
39 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
40 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
41 |
--------------------------------------------------------------------------------
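A minimal usage sketch of the helpers above, mirroring how train_rnn.py calls them; the experiment directory name is a placeholder.

    from utils.exp_utils import create_exp_dir

    # creates the directory, snapshots the listed scripts into <dir>/scripts,
    # and returns a logger bound to <dir>/log.txt
    logging = create_exp_dir('LM-TFM/example_run', scripts_to_save=['train_rnn.py'])
    logging('training started')  # printed to stdout and appended to log.txt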
/utils/adaptive_softmax.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | class AdaptiveLogSoftmax(nn.Module):
10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
11 | super(AdaptiveLogSoftmax, self).__init__()
12 |
13 | cutoffs = list(cutoffs)
14 |
15 | if (cutoffs != sorted(cutoffs)) \
16 | or (min(cutoffs) <= 0) \
17 | or (max(cutoffs) >= (n_classes - 1)) \
18 | or (len(set(cutoffs)) != len(cutoffs)) \
19 | or any([int(c) != c for c in cutoffs]):
20 |
21 | raise ValueError("cutoffs should be a sequence of unique, positive "
22 | "integers sorted in an increasing order, where "
23 | "each value is between 1 and n_classes-1")
24 |
25 | self.in_features = in_features
26 | self.n_classes = n_classes
27 | self.cutoffs = cutoffs + [n_classes]
28 |
29 | self.shortlist_size = self.cutoffs[0]
30 | self.n_clusters = len(self.cutoffs) - 1
31 | self.head_size = self.shortlist_size + self.n_clusters
32 |
33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
35 |
36 | self.keep_order = keep_order
37 |
38 |
39 | def forward(self, hidden, target, weight, bias, keep_order=False):
40 | if hidden.size(0) != target.size(0):
41 | raise RuntimeError('Input and target should have the same size '
42 | 'in the batch dimension.')
43 |
44 | head_weight = torch.cat(
45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0)
46 | head_bias = torch.cat(
47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0)
48 |
49 | head_logit = F.linear(hidden, head_weight, bias=head_bias)
50 | head_logprob = F.log_softmax(head_logit, dim=1)
51 |
52 | nll = torch.zeros_like(target,
53 | dtype=hidden.dtype, device=hidden.device)
54 |
55 | offset = 0
56 | cutoff_values = [0] + self.cutoffs
57 | for i in range(len(cutoff_values) - 1):
58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]
59 |
60 | mask_i = (target >= l_idx) & (target < h_idx)
61 | indices_i = mask_i.nonzero().squeeze()
62 |
63 | if indices_i.numel() == 0:
64 | continue
65 |
66 | target_i = target.index_select(0, indices_i) - l_idx
67 | head_logprob_i = head_logprob.index_select(0, indices_i)
68 |
69 | if i == 0:
70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
71 | else:
72 | weight_i = weight[l_idx:h_idx]
73 | bias_i = bias[l_idx:h_idx]
74 |
75 | hidden_i = hidden.index_select(0, indices_i)
76 |
77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
79 |
80 | # the logit of tail cluster i sits at column shortlist_size + i - 1 of the head
81 | logprob_i = head_logprob_i[:, self.shortlist_size + i - 1] + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
82 |
83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
84 | nll.index_copy_(0, indices_i, -logprob_i)
85 | else:
86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
87 |
88 | offset += logprob_i.size(0)
89 |
90 | return nll
91 |
--------------------------------------------------------------------------------
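A hedged usage sketch of AdaptiveLogSoftmax with dummy tensors; shapes follow the forward() code above, and the full-vocabulary weight and bias are supplied by the caller rather than owned by the module. Note that with keep_order=False the returned losses are packed cluster by cluster rather than in target order.

    import torch
    from utils.adaptive_softmax import AdaptiveLogSoftmax

    n_classes, d = 1000, 64
    crit = AdaptiveLogSoftmax(in_features=d, n_classes=n_classes, cutoffs=[200, 600])

    hidden = torch.randn(32, d)                  # [N, in_features]
    target = torch.randint(0, n_classes, (32,))  # [N] class indices
    weight = torch.randn(n_classes, d)           # full softmax weight, sliced per cluster
    bias = torch.zeros(n_classes)

    nll = crit(hidden, target, weight, bias)     # [N] negative log-likelihoods
    loss = nll.mean()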
/utils/data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.nn.parallel import DataParallel
3 | import torch
4 | from torch.nn.parallel._functions import Scatter
5 | from torch.nn.parallel.parallel_apply import parallel_apply
6 |
7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0):
8 | r"""
9 | Slices tensors into approximately equal chunks and
10 | distributes them across given GPUs. Duplicates
11 | references to objects that are not tensors.
12 | """
13 | def scatter_map(obj):
14 | if isinstance(obj, torch.Tensor):
15 | try:
16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
17 | except:
18 | print('obj', obj.size())
19 | print('dim', dim)
20 | print('chunk_sizes', chunk_sizes)
21 | quit()
22 | if isinstance(obj, tuple) and len(obj) > 0:
23 | return list(zip(*map(scatter_map, obj)))
24 | if isinstance(obj, list) and len(obj) > 0:
25 | return list(map(list, zip(*map(scatter_map, obj))))
26 | if isinstance(obj, dict) and len(obj) > 0:
27 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
28 | return [obj for targets in target_gpus]
29 |
30 | # After scatter_map is called, a scatter_map cell will exist. This cell
31 | # has a reference to the actual function scatter_map, which has references
32 | # to a closure that has a reference to the scatter_map cell (because the
33 | # fn is recursive). To avoid this reference cycle, we set the function to
34 | # None, clearing the cell
35 | try:
36 | return scatter_map(inputs)
37 | finally:
38 | scatter_map = None
39 |
40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
41 | r"""Scatter with support for kwargs dictionary"""
42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
44 | if len(inputs) < len(kwargs):
45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
46 | elif len(kwargs) < len(inputs):
47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
48 | inputs = tuple(inputs)
49 | kwargs = tuple(kwargs)
50 | return inputs, kwargs
51 |
52 | class BalancedDataParallel(DataParallel):
53 | def __init__(self, gpu0_bsz, *args, **kwargs):
54 | self.gpu0_bsz = gpu0_bsz
55 | super().__init__(*args, **kwargs)
56 |
57 | def forward(self, *inputs, **kwargs):
58 | if not self.device_ids:
59 | return self.module(*inputs, **kwargs)
60 | if self.gpu0_bsz == 0:
61 | device_ids = self.device_ids[1:]
62 | else:
63 | device_ids = self.device_ids
64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
65 | if len(self.device_ids) == 1:
66 | return self.module(*inputs[0], **kwargs[0])
67 | replicas = self.replicate(self.module, self.device_ids)
68 | if self.gpu0_bsz == 0:
69 | replicas = replicas[1:]
70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
71 | return self.gather(outputs, self.output_device)
72 |
73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs):
74 | return parallel_apply(replicas, inputs, kwargs, device_ids)
75 |
76 | def scatter(self, inputs, kwargs, device_ids):
77 | bsz = inputs[0].size(self.dim)
78 | num_dev = len(self.device_ids)
79 | gpu0_bsz = self.gpu0_bsz
80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
81 | if gpu0_bsz < bsz_unit:
82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
83 | delta = bsz - sum(chunk_sizes)
84 | for i in range(delta):
85 | chunk_sizes[i + 1] += 1
86 | if gpu0_bsz == 0:
87 | chunk_sizes = chunk_sizes[1:]
88 | else:
89 | return super().scatter(inputs, kwargs, device_ids)
90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
91 |
92 |
--------------------------------------------------------------------------------
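A hedged sketch of how BalancedDataParallel is meant to be wrapped around a model (train_rnn.py below does the same with the transformer); the module here is a placeholder and at least two visible GPUs are assumed.

    import torch.nn as nn
    from utils.data_parallel import BalancedDataParallel

    model = nn.Linear(512, 512)  # placeholder module
    # Batches shaped [seq_len, bsz, ...] are split along dim=1. When gpu0_bsz is
    # smaller than an even share, GPU 0 gets only gpu0_bsz examples and the other
    # GPUs split the remainder; otherwise the standard DataParallel split is used.
    para_model = BalancedDataParallel(gpu0_bsz=4, module=model, dim=1).cuda()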
/utils/log_uniform_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | class LogUniformSampler(object):
6 | def __init__(self, range_max, n_sample):
7 | """
8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
10 |
11 | expected count can be approximated by 1 - (1 - p)^n
12 | and we use a numerically stable version -expm1(num_tries * log1p(-p))
13 |
14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
15 | """
16 | with torch.no_grad():
17 | self.range_max = range_max
18 | log_indices = torch.arange(1., range_max+2., 1.).log_()
19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
20 | # print('P', self.dist.numpy().tolist()[-30:])
21 |
22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
23 |
24 | self.n_sample = n_sample
25 |
26 | def sample(self, labels):
27 | """
28 | labels: [b1, b2]
29 | Return
30 | true_log_probs: [b1, b2]
31 | samp_log_probs: [n_sample]
32 | neg_samples: [n_sample]
33 | """
34 |
35 | # neg_samples = torch.empty(0).long()
36 | n_sample = self.n_sample
37 | n_tries = 2 * n_sample
38 |
39 | with torch.no_grad():
40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
41 | device = labels.device
42 | neg_samples = neg_samples.to(device)
43 | true_log_probs = self.log_q[labels].to(device)
44 | samp_log_probs = self.log_q[neg_samples].to(device)
45 | return true_log_probs, samp_log_probs, neg_samples
46 |
47 | def sample_logits(embedding, bias, labels, inputs, sampler):
48 | """
49 | embedding: an nn.Embedding layer
50 | bias: [n_vocab]
51 | labels: [b1, b2]
52 | inputs: [b1, b2, n_emb]
53 | sampler: you may use a LogUniformSampler
54 | Return
55 | logits: [b1, b2, 1 + n_sample]
56 | """
57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
58 | n_sample = neg_samples.size(0)
59 | b1, b2 = labels.size(0), labels.size(1)
60 | all_ids = torch.cat([labels.view(-1), neg_samples])
61 | all_w = embedding(all_ids)
62 | true_w = all_w[: -n_sample].view(b1, b2, -1)
63 | sample_w = all_w[- n_sample:].view(n_sample, -1)
64 |
65 | all_b = bias[all_ids]
66 | true_b = all_b[: -n_sample].view(b1, b2)
67 | sample_b = all_b[- n_sample:]
68 |
69 | hit = (labels[:, :, None] == neg_samples).detach()
70 |
71 | true_logits = torch.einsum('ijk,ijk->ij',
72 | [true_w, inputs]) + true_b - true_log_probs
73 | sample_logits = torch.einsum('lk,ijk->ijl',
74 | [sample_w, inputs]) + sample_b - samp_log_probs
75 | sample_logits.masked_fill_(hit, -1e30)
76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
77 |
78 | return logits
79 |
80 |
81 | # class LogUniformSampler(object):
82 | # def __init__(self, range_max, unique=False):
83 | # """
84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
85 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
86 | # """
87 | # self.range_max = range_max
88 | # log_indices = torch.arange(1., range_max+2., 1.).log_()
89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
90 |
91 | # self.unique = unique
92 |
93 | # if self.unique:
94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
95 |
96 | # def sample(self, n_sample, labels):
97 | # pos_sample, new_labels = labels.unique(return_inverse=True)
98 | # n_pos_sample = pos_sample.size(0)
99 | # n_neg_sample = n_sample - n_pos_sample
100 |
101 | # if self.unique:
102 | # self.exclude_mask.index_fill_(0, pos_sample, 1)
103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
104 | # self.exclude_mask.index_fill_(0, pos_sample, 0)
105 | # else:
106 | # sample_dist = self.dist
107 |
108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample)
109 |
110 | # sample = torch.cat([pos_sample, neg_sample])
111 | # sample_prob = self.dist[sample]
112 |
113 | # return new_labels, sample, sample_prob
114 |
115 |
116 | if __name__ == '__main__':
117 | S, B = 3, 4
118 | n_vocab = 10000
119 | n_sample = 5
120 | H = 32
121 |
122 | labels = torch.LongTensor(S, B).random_(0, n_vocab)
123 |
124 | # sampler = LogUniformSampler(n_vocab, unique=False)
125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
126 |
127 | sampler = LogUniformSampler(n_vocab, n_sample)
128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
129 |
130 | # print('true_probs', true_probs.numpy().tolist())
131 | # print('samp_probs', samp_probs.numpy().tolist())
132 | # print('neg_samples', neg_samples.numpy().tolist())
133 |
134 | # print('sum', torch.sum(sampler.dist).item())
135 |
136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
137 |
138 | embedding = nn.Embedding(n_vocab, H)
139 | bias = torch.zeros(n_vocab)
140 | inputs = torch.Tensor(S, B, H).normal_()
141 |
142 | logits = sample_logits(embedding, bias, labels, inputs, sampler)
143 | print('logits', logits.detach().numpy().tolist())
144 | print('logits shape', logits.size())
145 | # sample_logits returns only the logits; column 0 is the true class and the
146 | # remaining columns correspond to the sampled negatives
147 |
148 |
--------------------------------------------------------------------------------
/utils/proj_adaptive_softmax.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1])
11 |
12 | class ProjectedAdaptiveLogSoftmax(nn.Module):
13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
14 | keep_order=False):
15 | super(ProjectedAdaptiveLogSoftmax, self).__init__()
16 |
17 | self.n_token = n_token
18 | self.d_embed = d_embed
19 | self.d_proj = d_proj
20 |
21 | self.cutoffs = cutoffs + [n_token]
22 | self.cutoff_ends = [0] + self.cutoffs
23 | self.div_val = div_val
24 |
25 | self.shortlist_size = self.cutoffs[0]
26 | self.n_clusters = len(self.cutoffs) - 1
27 | self.head_size = self.shortlist_size + self.n_clusters
28 |
29 | if self.n_clusters > 0:
30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
32 |
33 | self.out_layers = nn.ModuleList()
34 | self.out_projs = nn.ParameterList()
35 |
36 | if div_val == 1:
37 | for i in range(len(self.cutoffs)):
38 | if d_proj != d_embed:
39 | self.out_projs.append(
40 | nn.Parameter(torch.Tensor(d_proj, d_embed))
41 | )
42 | else:
43 | self.out_projs.append(None)
44 |
45 | self.out_layers.append(nn.Linear(d_embed, n_token))
46 | else:
47 | for i in range(len(self.cutoffs)):
48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
49 | d_emb_i = d_embed // (div_val ** i)
50 |
51 | self.out_projs.append(
52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i))
53 | )
54 |
55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
56 |
57 | self.keep_order = keep_order
58 |
59 | def _compute_logit(self, hidden, weight, bias, proj):
60 | if proj is None:
61 | logit = F.linear(hidden, weight, bias=bias)
62 | else:
63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
64 | proj_hid = F.linear(hidden, proj.t().contiguous())
65 | logit = F.linear(proj_hid, weight, bias=bias)
66 | # else:
67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
68 | # if bias is not None:
69 | # logit = logit + bias
70 |
71 | return logit
72 |
73 | def forward(self, hidden, target, keep_order=False):
74 | '''
75 | hidden :: [len*bsz x d_proj]
76 | target :: [len*bsz]
77 | '''
78 |
79 | if hidden.size(0) != target.size(0):
80 | raise RuntimeError('Input and target should have the same size '
81 | 'in the batch dimension.')
82 |
83 | if self.n_clusters == 0:
84 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
85 | self.out_layers[0].bias, self.out_projs[0])
86 | nll = -F.log_softmax(logit, dim=-1) \
87 | .gather(1, target.unsqueeze(1)).squeeze(1)
88 | else:
89 | # construct weights and biases
90 | weights, biases = [], []
91 | for i in range(len(self.cutoffs)):
92 | if self.div_val == 1:
93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
94 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
95 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
96 | else:
97 | weight_i = self.out_layers[i].weight
98 | bias_i = self.out_layers[i].bias
99 |
100 | if i == 0:
101 | weight_i = torch.cat(
102 | [weight_i, self.cluster_weight], dim=0)
103 | bias_i = torch.cat(
104 | [bias_i, self.cluster_bias], dim=0)
105 |
106 | weights.append(weight_i)
107 | biases.append(bias_i)
108 |
109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
110 |
111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
112 | head_logprob = F.log_softmax(head_logit, dim=1)
113 |
114 | nll = torch.zeros_like(target,
115 | dtype=hidden.dtype, device=hidden.device)
116 |
117 | offset = 0
118 | cutoff_values = [0] + self.cutoffs
119 | for i in range(len(cutoff_values) - 1):
120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
121 |
122 | mask_i = (target >= l_idx) & (target < r_idx)
123 | indices_i = mask_i.nonzero().squeeze()
124 |
125 | if indices_i.numel() == 0:
126 | continue
127 |
128 | target_i = target.index_select(0, indices_i) - l_idx
129 | head_logprob_i = head_logprob.index_select(0, indices_i)
130 |
131 | if i == 0:
132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
133 | else:
134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
135 |
136 | hidden_i = hidden.index_select(0, indices_i)
137 |
138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
140 |
141 | # the logit of tail cluster i sits at column shortlist_size + i - 1 of the head
142 | logprob_i = head_logprob_i[:, self.shortlist_size + i - 1] + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
143 |
144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
145 | nll.index_copy_(0, indices_i, -logprob_i)
146 | else:
147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
148 |
149 | offset += logprob_i.size(0)
150 |
151 | return nll
152 |
--------------------------------------------------------------------------------
/utils/vocabulary.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import Counter, OrderedDict
3 |
4 | import torch
5 |
6 | class Vocab(object):
7 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
8 | delimiter=None, vocab_file=None):
9 | self.counter = Counter()
10 | self.special = special
11 | self.min_freq = min_freq
12 | self.max_size = max_size
13 | self.lower_case = lower_case
14 | self.delimiter = delimiter
15 | self.vocab_file = vocab_file
16 |
17 | def tokenize(self, line, add_eos=False, add_double_eos=False):
18 | line = line.strip()
19 | # convert to lower case
20 | if self.lower_case:
21 | line = line.lower()
22 |
23 | # empty delimiter '' will evaluate False
24 | if self.delimiter == '':
25 | symbols = line
26 | else:
27 | symbols = line.split(self.delimiter)
28 |
29 | if add_double_eos: # lm1b
30 | return ['<S>'] + symbols + ['<S>']
31 | elif add_eos:
32 | return symbols + ['<eos>']
33 | else:
34 | return symbols
35 |
36 | def count_file(self, path, verbose=False, add_eos=False):
37 | if verbose: print('counting file {} ...'.format(path))
38 | assert os.path.exists(path)
39 |
40 | sents = []
41 | with open(path, 'r', encoding='utf-8') as f:
42 | for idx, line in enumerate(f):
43 | if verbose and idx > 0 and idx % 500000 == 0:
44 | print(' line {}'.format(idx))
45 | symbols = self.tokenize(line, add_eos=add_eos)
46 | self.counter.update(symbols)
47 | sents.append(symbols)
48 |
49 | return sents
50 |
51 | def count_sents(self, sents, verbose=False):
52 | """
53 | sents : a list of sentences, each a list of tokenized symbols
54 | """
55 | if verbose: print('counting {} sents ...'.format(len(sents)))
56 | for idx, symbols in enumerate(sents):
57 | if verbose and idx > 0 and idx % 500000 == 0:
58 | print(' line {}'.format(idx))
59 | self.counter.update(symbols)
60 |
61 | def _build_from_file(self, vocab_file):
62 | self.idx2sym = []
63 | self.sym2idx = OrderedDict()
64 |
65 | with open(vocab_file, 'r', encoding='utf-8') as f:
66 | for line in f:
67 | symb = line.strip().split()[1]
68 | self.add_symbol(symb)
69 | self.unk_idx = self.sym2idx['<unk>']
70 |
71 | def build_vocab(self):
72 | if self.vocab_file:
73 | print('building vocab from {}'.format(self.vocab_file))
74 | self._build_from_file(self.vocab_file)
75 | print('final vocab size {}'.format(len(self)))
76 | else:
77 | print('building vocab with min_freq={}, max_size={}'.format(
78 | self.min_freq, self.max_size))
79 | self.idx2sym = []
80 | self.sym2idx = OrderedDict()
81 |
82 | for sym in self.special:
83 | self.add_special(sym)
84 |
85 | for sym, cnt in self.counter.most_common(self.max_size):
86 | if cnt < self.min_freq: break
87 | self.add_symbol(sym)
88 |
89 | print('final vocab size {} from {} unique tokens'.format(
90 | len(self), len(self.counter)))
91 |
92 | def build_distribution(self):
93 | self.freq_map = [0] * len(self.sym2idx)
94 | for word, idx in self.sym2idx.items():
95 | self.freq_map[idx] += self.counter[word]
96 | self.freq_map = torch.tensor(self.freq_map)
97 |
98 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
99 | add_double_eos=False):
100 | if verbose: print('encoding file {} ...'.format(path))
101 | assert os.path.exists(path)
102 | encoded = []
103 | with open(path, 'r', encoding='utf-8') as f:
104 | for idx, line in enumerate(f):
105 | if verbose and idx > 0 and idx % 500000 == 0:
106 | print(' line {}'.format(idx))
107 | symbols = self.tokenize(line, add_eos=add_eos,
108 | add_double_eos=add_double_eos)
109 | encoded.append(self.convert_to_tensor(symbols))
110 |
111 | if ordered:
112 | encoded = torch.cat(encoded)
113 |
114 | return encoded
115 |
116 | def encode_sents(self, sents, ordered=False, verbose=False):
117 | if verbose: print('encoding {} sents ...'.format(len(sents)))
118 | encoded = []
119 | for idx, symbols in enumerate(sents):
120 | if verbose and idx > 0 and idx % 500000 == 0:
121 | print(' line {}'.format(idx))
122 | encoded.append(self.convert_to_tensor(symbols))
123 |
124 | if ordered:
125 | encoded = torch.cat(encoded)
126 |
127 | return encoded
128 |
129 | def add_special(self, sym):
130 | if sym not in self.sym2idx:
131 | self.idx2sym.append(sym)
132 | self.sym2idx[sym] = len(self.idx2sym) - 1
133 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
134 |
135 | def add_symbol(self, sym):
136 | if sym not in self.sym2idx:
137 | self.idx2sym.append(sym)
138 | self.sym2idx[sym] = len(self.idx2sym) - 1
139 |
140 | def get_sym(self, idx):
141 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
142 | return self.idx2sym[idx]
143 |
144 | def get_idx(self, sym):
145 | if sym in self.sym2idx:
146 | return self.sym2idx[sym]
147 | else:
148 | # print('encounter unk {}'.format(sym))
149 | assert '<eos>' not in sym
150 | assert hasattr(self, 'unk_idx')
151 | return self.sym2idx.get(sym, self.unk_idx)
152 |
153 | def get_symbols(self, indices):
154 | return [self.get_sym(idx) for idx in indices]
155 |
156 | def get_indices(self, symbols):
157 | return [self.get_idx(sym) for sym in symbols]
158 |
159 | def convert_to_tensor(self, symbols):
160 | return torch.LongTensor(self.get_indices(symbols))
161 |
162 | def convert_to_sent(self, indices, exclude=None):
163 | if exclude is None:
164 | return ' '.join([self.get_sym(idx) for idx in indices])
165 | else:
166 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
167 |
168 | def __len__(self):
169 | return len(self.idx2sym)
170 |
--------------------------------------------------------------------------------
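A short sketch of the Vocab lifecycle as Corpus in data_utils.py drives it; the file path is a placeholder. Note that the AMI branch of get_lm_corpus actually builds the vocabulary from a dictionary.txt vocab_file, while counting from the training text, as done here, is the other path this class supports.

    from utils.vocabulary import Vocab

    vocab = Vocab(special=['<eos>'], lower_case=False)
    vocab.count_file('data/AMI/train.txt')   # accumulate token counts
    vocab.build_vocab()                      # special symbols first, then by frequency
    vocab.build_distribution()               # unigram counts indexed by symbol id
    train = vocab.encode_file('data/AMI/train.txt', ordered=True)  # one flat LongTensor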
/data_utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import glob
3 |
4 | from collections import Counter, OrderedDict
5 | import numpy as np
6 | import torch
7 |
8 | from utils.vocabulary import Vocab
9 |
10 | class LMOrderedIterator(object):
11 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, future_len=None):
12 | """
13 | data -- LongTensor -- the LongTensor is strictly ordered
14 | """
15 | self.bsz = bsz
16 | self.bptt = bptt
17 | self.ext_len = ext_len if ext_len is not None else 0
18 | self.future_len = future_len if future_len is not None else 0
19 |
20 | self.device = device
21 |
22 | # Work out how cleanly we can divide the dataset into bsz parts.
23 | self.n_step = data.size(0) // bsz
24 |
25 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
26 | data = data.narrow(0, 0, self.n_step * bsz)
27 |
28 | # Evenly divide the data across the bsz batches.
29 | self.data = data.view(bsz, -1).t().contiguous().to(device)
30 |
31 | # Number of mini-batches
32 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
33 |
34 | def get_batch(self, i, bptt=None):
35 | if bptt is None: bptt = self.bptt
36 | seq_len = min(bptt, self.data.size(0) - 1 - i)
37 | future_seqlen = min(self.future_len, self.data.size(0) - 1 - i - seq_len)
38 |
39 | end_idx = i + seq_len + future_seqlen
40 | beg_idx = max(0, i - self.ext_len)
41 |
42 | data = self.data[beg_idx:end_idx]
43 | target = self.data[i+1:i+1+seq_len]
44 |
45 | return data, target, seq_len, future_seqlen
46 |
47 | def get_fixlen_iter(self, start=0):
48 | for i in range(start, self.data.size(0) - 1, self.bptt):
49 | yield self.get_batch(i)
50 |
51 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
52 | max_len = self.bptt + max_deviation * std
53 | i = start
54 | while True:
55 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
56 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
57 | data, target, seq_len, future_seqlen = self.get_batch(i, bptt)
58 | i += seq_len
59 | yield data, target, seq_len, future_seqlen
60 | if i >= self.data.size(0) - 2:
61 | break
62 |
63 | def __iter__(self):
64 | return self.get_fixlen_iter()
65 |
66 |
67 | class LMShuffledIterator(object):
68 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
69 | """
70 | data -- list[LongTensor] -- there is no order among the LongTensors
71 | """
72 | self.data = data
73 |
74 | self.bsz = bsz
75 | self.bptt = bptt
76 | self.ext_len = ext_len if ext_len is not None else 0
77 |
78 | self.device = device
79 | self.shuffle = shuffle
80 |
81 | def get_sent_stream(self):
82 | # index iterator
83 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
84 | else np.array(range(len(self.data)))
85 |
86 | # sentence iterator
87 | for idx in epoch_indices:
88 | yield self.data[idx]
89 |
90 | def stream_iterator(self, sent_stream):
91 | # streams for each data in the batch
92 | streams = [None] * self.bsz
93 |
94 | data = torch.LongTensor(self.bptt, self.bsz)
95 | target = torch.LongTensor(self.bptt, self.bsz)
96 |
97 | n_retain = 0
98 |
99 | while True:
100 | # data : [n_retain+bptt x bsz]
101 | # target : [bptt x bsz]
102 | data[n_retain:].fill_(-1)
103 | target.fill_(-1)
104 |
105 | valid_batch = True
106 |
107 | for i in range(self.bsz):
108 | n_filled = 0
109 | try:
110 | while n_filled < self.bptt:
111 | if streams[i] is None or len(streams[i]) <= 1:
112 | streams[i] = next(sent_stream)
113 | # number of new tokens to fill in
114 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
115 | # first n_retain tokens are retained from last batch
116 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
117 | streams[i][:n_new]
118 | target[n_filled:n_filled+n_new, i] = \
119 | streams[i][1:n_new+1]
120 | streams[i] = streams[i][n_new:]
121 | n_filled += n_new
122 | except StopIteration:
123 | valid_batch = False
124 | break
125 |
126 | if not valid_batch:
127 | return
128 |
129 | data = data.to(self.device)
130 | target = target.to(self.device)
131 |
132 | yield data, target, self.bptt
133 |
134 | n_retain = min(data.size(0), self.ext_len)
135 | if n_retain > 0:
136 | data[:n_retain] = data[-n_retain:]
137 | data.resize_(n_retain + self.bptt, data.size(1))
138 |
139 | def __iter__(self):
140 | # sent_stream is an iterator
141 | sent_stream = self.get_sent_stream()
142 |
143 | for batch in self.stream_iterator(sent_stream):
144 | yield batch
145 |
146 |
147 | class LMMultiFileIterator(LMShuffledIterator):
148 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
149 | shuffle=False):
150 |
151 | self.paths = paths
152 | self.vocab = vocab
153 |
154 | self.bsz = bsz
155 | self.bptt = bptt
156 | self.ext_len = ext_len if ext_len is not None else 0
157 |
158 | self.device = device
159 | self.shuffle = shuffle
160 |
161 | def get_sent_stream(self, path):
162 | sents = self.vocab.encode_file(path, add_double_eos=True)
163 | if self.shuffle:
164 | np.random.shuffle(sents)
165 | sent_stream = iter(sents)
166 |
167 | return sent_stream
168 |
169 | def __iter__(self):
170 | if self.shuffle:
171 | np.random.shuffle(self.paths)
172 |
173 | for path in self.paths:
174 | # sent_stream is an iterator
175 | sent_stream = self.get_sent_stream(path)
176 | for batch in self.stream_iterator(sent_stream):
177 | yield batch
178 |
179 |
180 | class Corpus(object):
181 | def __init__(self, path, dataset, *args, **kwargs):
182 | self.dataset = dataset
183 | self.vocab = Vocab(*args, **kwargs)
184 |
185 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'AMI']:
186 | self.vocab.count_file(os.path.join(path, 'train.txt'))
187 | self.vocab.count_file(os.path.join(path, 'valid.txt'))
188 | self.vocab.count_file(os.path.join(path, 'test.txt'))
189 | elif self.dataset == 'wt103':
190 | self.vocab.count_file(os.path.join(path, 'train.txt'))
191 | elif self.dataset == 'lm1b':
192 | train_path_pattern = os.path.join(
193 | path, '1-billion-word-language-modeling-benchmark-r13output',
194 | 'training-monolingual.tokenized.shuffled', 'news.en-*')
195 | train_paths = glob.glob(train_path_pattern)
196 | # the vocab will load from file when build_vocab() is called
197 |
198 | self.vocab.build_vocab()
199 | # build unigram distribution of the vocab
200 | self.vocab.build_distribution()
201 |
202 | if self.dataset in ['ptb', 'wt2', 'wt103', 'AMI']:
203 | self.train = self.vocab.encode_file(
204 | os.path.join(path, 'train.txt'), ordered=True)
205 | self.valid = self.vocab.encode_file(
206 | os.path.join(path, 'valid.txt'), ordered=True)
207 | self.test = self.vocab.encode_file(
208 | os.path.join(path, 'test.txt'), ordered=True)
209 | elif self.dataset in ['enwik8', 'text8']:
210 | self.train = self.vocab.encode_file(
211 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
212 | self.valid = self.vocab.encode_file(
213 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
214 | self.test = self.vocab.encode_file(
215 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
216 | elif self.dataset == 'lm1b':
217 | self.train = train_paths
218 | self.valid = self.vocab.encode_file(
219 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
220 | self.test = self.vocab.encode_file(
221 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
222 |
223 | def get_iterator(self, split, *args, **kwargs):
224 | if split == 'train':
225 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'AMI']:
226 | data_iter = LMOrderedIterator(self.train, *args, **kwargs)
227 | elif self.dataset == 'lm1b':
228 | kwargs['shuffle'] = True
229 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
230 | elif split in ['valid', 'test']:
231 | data = self.valid if split == 'valid' else self.test
232 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'AMI']:
233 | data_iter = LMOrderedIterator(data, *args, **kwargs)
234 | elif self.dataset == 'lm1b':
235 | data_iter = LMShuffledIterator(data, *args, **kwargs)
236 |
237 | return data_iter
238 |
239 |
240 | def get_lm_corpus(datadir, dataset):
241 | fn = os.path.join(datadir, 'cache.pt')
242 | if os.path.exists(fn):
243 | print('Loading cached dataset...')
244 | corpus = torch.load(fn)
245 | else:
246 | print('Producing dataset {}...'.format(dataset))
247 | kwargs = {}
248 | if dataset in ['wt103', 'wt2', 'AMI']:
249 | kwargs['special'] = ['<eos>']
250 | kwargs['lower_case'] = False
251 | kwargs['vocab_file'] = os.path.join(datadir, 'dictionary.txt')
252 | elif dataset == 'ptb':
253 | kwargs['special'] = ['<eos>']
254 | kwargs['lower_case'] = True
255 | elif dataset == 'lm1b':
256 | kwargs['special'] = []
257 | kwargs['lower_case'] = False
258 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
259 | elif dataset in ['enwik8', 'text8']:
260 | pass
261 |
262 | corpus = Corpus(datadir, dataset, **kwargs)
263 | # torch.save(corpus, fn)
264 |
265 | return corpus
266 |
267 | if __name__ == '__main__':
268 | import argparse
269 | parser = argparse.ArgumentParser(description='unit test')
270 | parser.add_argument('--datadir', type=str, default='../data/text8',
271 | help='location of the data corpus')
272 | parser.add_argument('--dataset', type=str, default='text8',
273 | choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
274 | help='dataset name')
275 | args = parser.parse_args()
276 |
277 | corpus = get_lm_corpus(args.datadir, args.dataset)
278 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))
279 |
--------------------------------------------------------------------------------
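A minimal sketch of LMOrderedIterator on a toy corpus, showing the batch shapes produced by get_batch(); the tensor here stands in for an encoded corpus from Vocab.encode_file.

    import torch
    from data_utils import LMOrderedIterator

    data = torch.arange(10000)  # stand-in for a flat encoded corpus (LongTensor)
    it = LMOrderedIterator(data, bsz=4, bptt=32, ext_len=0, future_len=0)
    for chunk, target, seq_len, future_len in it:
        # chunk: [seq_len + future_len, 4], target: [seq_len, 4] (next-token shift)
        break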
/train_rnn.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import argparse
3 | import time
4 | import math
5 | import os, sys
6 | import itertools
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 |
14 | from data_utils import get_lm_corpus
15 | from mem_transformer_rnn import MemTransformerLM
16 | from utils.exp_utils import create_exp_dir
17 | from utils.data_parallel import BalancedDataParallel
18 |
19 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
20 | parser.add_argument('--data', type=str, default='../data/wikitext-103',
21 | help='location of the data corpus')
22 | parser.add_argument('--dataset', type=str, default='wt103',
23 | choices=['wt103', 'lm1b', 'enwik8', 'text8', 'AMI'],
24 | help='dataset name')
25 | parser.add_argument('--n_layer', type=int, default=12,
26 | help='number of total layers')
27 | parser.add_argument('--n_head', type=int, default=10,
28 | help='number of heads')
29 | parser.add_argument('--d_head', type=int, default=50,
30 | help='head dimension')
31 | parser.add_argument('--d_embed', type=int, default=-1,
32 | help='embedding dimension')
33 | parser.add_argument('--d_model', type=int, default=500,
34 | help='model dimension')
35 | parser.add_argument('--d_inner', type=int, default=1000,
36 | help='inner dimension in FF')
37 | parser.add_argument('--dropout', type=float, default=0.0,
38 | help='global dropout rate')
39 | parser.add_argument('--dropatt', type=float, default=0.0,
40 | help='attention probability dropout rate')
41 | parser.add_argument('--init', default='normal', type=str,
42 | help='parameter initializer to use.')
43 | parser.add_argument('--emb_init', default='normal', type=str,
44 | help='parameter initializer to use.')
45 | parser.add_argument('--init_range', type=float, default=0.1,
46 | help='parameters initialized by U(-init_range, init_range)')
47 | parser.add_argument('--emb_init_range', type=float, default=0.01,
48 | help='parameters initialized by U(-init_range, init_range)')
49 | parser.add_argument('--init_std', type=float, default=0.02,
50 | help='parameters initialized by N(0, init_std)')
51 | parser.add_argument('--proj_init_std', type=float, default=0.01,
52 | help='parameters initialized by N(0, init_std)')
53 | parser.add_argument('--optim', default='adam', type=str,
54 | choices=['adam', 'sgd', 'adagrad'],
55 | help='optimizer to use.')
56 | parser.add_argument('--lr', type=float, default=0.00025,
57 | help='initial learning rate (0.00025|5 for adam|sgd)')
58 | parser.add_argument('--mom', type=float, default=0.0,
59 | help='momentum for sgd')
60 | parser.add_argument('--scheduler', default='cosine', type=str,
61 | choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
62 | help='lr scheduler to use.')
63 | parser.add_argument('--warmup_step', type=int, default=0,
help='number of learning rate warmup steps')
65 | parser.add_argument('--decay_rate', type=float, default=0.5,
66 | help='decay factor when ReduceLROnPlateau is used')
67 | parser.add_argument('--lr_min', type=float, default=0.0,
68 | help='minimum learning rate during annealing')
69 | parser.add_argument('--clip', type=float, default=0.25,
70 | help='gradient clipping')
71 | parser.add_argument('--clip_nonemb', action='store_true',
72 | help='only clip the gradient of non-embedding params')
73 | parser.add_argument('--max_step', type=int, default=100000,
help='maximum number of training steps')
75 | parser.add_argument('--batch_size', type=int, default=60,
76 | help='batch size')
77 | parser.add_argument('--batch_chunk', type=int, default=1,
78 | help='split batch into chunks to save memory')
79 | parser.add_argument('--tgt_len', type=int, default=70,
80 | help='number of tokens to predict')
81 | parser.add_argument('--eval_tgt_len', type=int, default=50,
82 | help='number of tokens to predict for evaluation')
83 | parser.add_argument('--ext_len', type=int, default=0,
84 | help='length of the extended context')
85 | parser.add_argument('--future_len', type=int, default=0,
86 | help='length of the future context')
87 | parser.add_argument('--mem_len', type=int, default=0,
help='length of the memory retained from previous segments')
89 | parser.add_argument('--not_tied', action='store_true',
90 | help='do not tie the word embedding and softmax weights')
91 | parser.add_argument('--seed', type=int, default=1111,
92 | help='random seed')
93 | parser.add_argument('--cuda', action='store_true',
94 | help='use CUDA')
95 | parser.add_argument('--adaptive', action='store_true',
96 | help='use adaptive softmax')
97 | parser.add_argument('--div_val', type=int, default=1,
help='divisor value for adaptive input and softmax')
99 | parser.add_argument('--pre_lnorm', action='store_true',
100 | help='apply LayerNorm to the input instead of the output')
101 | parser.add_argument('--varlen', action='store_true',
102 | help='use variable length')
103 | parser.add_argument('--multi_gpu', action='store_true',
104 | help='use multiple GPU')
105 | parser.add_argument('--log-interval', type=int, default=200,
106 | help='report interval')
107 | parser.add_argument('--eval-interval', type=int, default=4000,
108 | help='evaluation interval')
109 | parser.add_argument('--work_dir', default='LM-TFM', type=str,
110 | help='experiment directory.')
111 | parser.add_argument('--restart', action='store_true',
112 | help='restart training from the saved checkpoint')
113 | parser.add_argument('--restart_dir', type=str, default='',
114 | help='restart dir')
115 | parser.add_argument('--debug', action='store_true',
116 | help='run in debug mode (do not create exp dir)')
117 | parser.add_argument('--same_length', action='store_true',
118 | help='use the same attn length for all tokens')
119 | parser.add_argument('--attn_type', type=int, default=0,
120 | help='attention type. 0 for ours, 1 for Shaw et al,'
121 | '2 for Vaswani et al, 3 for Al Rfou et al.')
122 | parser.add_argument('--clamp_len', type=int, default=-1,
123 | help='use the same pos embeddings after clamp_len')
124 | parser.add_argument('--eta_min', type=float, default=0.0,
125 | help='min learning rate for cosine scheduler')
126 | parser.add_argument('--gpu0_bsz', type=int, default=-1,
127 | help='batch size on gpu 0')
128 | parser.add_argument('--max_eval_steps', type=int, default=-1,
129 | help='max eval steps')
130 | parser.add_argument('--sample_softmax', type=int, default=-1,
131 | help='number of samples in sampled softmax')
132 | parser.add_argument('--patience', type=int, default=0,
133 | help='patience')
134 | parser.add_argument('--finetune_v2', action='store_true',
135 | help='finetune v2')
136 | parser.add_argument('--finetune_v3', action='store_true',
137 | help='finetune v3')
138 | parser.add_argument('--fp16', action='store_true',
139 | help='Run in pseudo-fp16 mode (fp16 storage fp32 math).')
140 | parser.add_argument('--static-loss-scale', type=float, default=1,
141 | help='Static loss scale, positive power of 2 values can '
142 | 'improve fp16 convergence.')
143 | parser.add_argument('--dynamic-loss-scale', action='store_true',
144 | help='Use dynamic loss scaling. If supplied, this argument'
145 | ' supersedes --static-loss-scale.')
146 | parser.add_argument('--rnnenc', action='store_true',
147 | help='use rnn encoder')
148 | parser.add_argument('--rnnmodel', type=str, default="",
149 | help='load pretrained rnn model')
150 | parser.add_argument('--rnndim', type=int, default=500,
151 | help='dimension of rnn hidden state')
152 | parser.add_argument('--layerlist', type=str, default='0',
153 | help='layers to insert rnn')
154 | parser.add_argument('--evalmode', action='store_true',
155 | help='Do evaluate forward only')
156 | parser.add_argument('--pen_layerlist', type=str, default='0',
157 | help='layers to apply attention penalty')
158 | parser.add_argument('--p_scale', type=float, default=0.0,
159 | help='penalty scaling factor')
160 | parser.add_argument('--merge_type', type=str, default="direct",
161 | help='Type of merging rnn hidden states')
162 |
163 | args = parser.parse_args()
164 | args.tied = not args.not_tied
165 |
166 | if args.d_embed < 0:
167 | args.d_embed = args.d_model
168 |
169 | assert args.ext_len >= 0, 'extended context length must be non-negative'
170 | assert args.batch_size % args.batch_chunk == 0
171 |
172 | if not args.evalmode:
173 | args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
174 | logging = create_exp_dir(args.work_dir,
175 | scripts_to_save=['train_rnn.py', 'mem_transformer_rnn.py'], debug=args.debug)
176 |
177 | # Set the random seed manually for reproducibility.
178 | np.random.seed(args.seed)
179 | torch.manual_seed(args.seed)
180 | if torch.cuda.is_available():
181 | if not args.cuda:
182 | print('WARNING: You have a CUDA device, so you should probably run with --cuda')
183 | else:
184 | torch.cuda.manual_seed_all(args.seed)
185 |
186 | # Validate `--fp16` option
187 | if args.fp16:
188 | if not args.cuda:
189 | print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
190 | args.fp16 = False
191 | else:
192 | try:
193 | from apex.fp16_utils import FP16_Optimizer
except ImportError:
195 | print('WARNING: apex not installed, ignoring --fp16 option')
196 | args.fp16 = False
197 |
198 | device = torch.device('cuda' if args.cuda else 'cpu')
199 |
200 | ###############################################################################
201 | # Load data
202 | ###############################################################################
203 | corpus = get_lm_corpus(args.data, args.dataset)
204 | ntokens = len(corpus.vocab)
205 | args.n_token = ntokens
206 |
207 | eval_batch_size = 10
208 | tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
209 | device=device, ext_len=args.ext_len, future_len=args.future_len)
210 | va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,
211 | device=device, ext_len=args.ext_len, future_len=args.future_len)
212 | te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
213 | device=device, ext_len=args.ext_len, future_len=args.future_len)
214 |
215 | # adaptive softmax / embedding
216 | cutoffs, tie_projs = [], [False]
217 | if args.adaptive:
218 | assert args.dataset in ['wt103', 'lm1b']
219 | if args.dataset == 'wt103':
220 | cutoffs = [20000, 40000, 200000]
221 | tie_projs += [True] * len(cutoffs)
222 | elif args.dataset == 'lm1b':
223 | cutoffs = [60000, 100000, 640000]
224 | tie_projs += [False] * len(cutoffs)
225 |
226 | ###############################################################################
227 | # Build the model
228 | ###############################################################################
229 | def init_weight(weight):
230 | if args.init == 'uniform':
231 | nn.init.uniform_(weight, -args.init_range, args.init_range)
232 | elif args.init == 'normal':
233 | nn.init.normal_(weight, 0.0, args.init_std)
234 |
235 | def init_bias(bias):
236 | nn.init.constant_(bias, 0.0)
237 |
238 | def weights_init(m):
239 | classname = m.__class__.__name__
240 | if classname.find('Linear') != -1:
241 | if hasattr(m, 'weight') and m.weight is not None:
242 | init_weight(m.weight)
243 | if hasattr(m, 'bias') and m.bias is not None:
244 | init_bias(m.bias)
245 | elif classname.find('AdaptiveEmbedding') != -1:
246 | if hasattr(m, 'emb_projs'):
247 | for i in range(len(m.emb_projs)):
248 | if m.emb_projs[i] is not None:
249 | nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
250 | elif classname.find('Embedding') != -1:
251 | if hasattr(m, 'weight'):
252 | init_weight(m.weight)
253 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
254 | if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
255 | init_weight(m.cluster_weight)
256 | if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
257 | init_bias(m.cluster_bias)
258 | if hasattr(m, 'out_projs'):
259 | for i in range(len(m.out_projs)):
260 | if m.out_projs[i] is not None:
261 | nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
262 | elif classname.find('LayerNorm') != -1:
263 | if hasattr(m, 'weight'):
264 | nn.init.normal_(m.weight, 1.0, args.init_std)
265 | if hasattr(m, 'bias') and m.bias is not None:
266 | init_bias(m.bias)
267 | elif classname.find('TransformerLM') != -1:
268 | if hasattr(m, 'r_emb'):
269 | init_weight(m.r_emb)
270 | if hasattr(m, 'r_w_bias'):
271 | init_weight(m.r_w_bias)
272 | if hasattr(m, 'r_r_bias'):
273 | init_weight(m.r_r_bias)
274 | if hasattr(m, 'r_bias'):
275 | init_bias(m.r_bias)
276 |
277 | def repackage_hidden(h):
278 | """Wraps hidden states in new Tensors, to detach them from their history."""
279 | if isinstance(h, torch.Tensor):
280 | return h.detach()
281 | else:
282 | return tuple(repackage_hidden(v) for v in h)
283 |
284 | def update_dropout(m):
285 | classname = m.__class__.__name__
286 | if classname.find('Dropout') != -1:
287 | if hasattr(m, 'p'):
288 | m.p = args.dropout
289 |
290 | def update_dropatt(m):
291 | if hasattr(m, 'dropatt'):
292 | m.dropatt.p = args.dropatt
293 |
294 | if args.restart:
295 | with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
296 | model = torch.load(f)
297 | if not args.fp16:
298 | model = model.float()
299 | model.apply(update_dropout)
300 | model.apply(update_dropatt)
301 | else:
302 | model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
303 | args.d_head, args.d_inner, args.dropout, args.dropatt,
304 | tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
305 | tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
306 | ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
307 | same_length=args.same_length, attn_type=args.attn_type,
308 | clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
309 | rnnenc=args.rnnenc, rnndim=args.rnndim,
310 | layer_list=args.layerlist, future_len=args.future_len,
311 | attn_layerlist=args.pen_layerlist, merge_type=args.merge_type)
312 | model.apply(weights_init)
313 | model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
314 | args.n_all_param = sum([p.nelement() for p in model.parameters()])
315 | args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
316 |
317 | # use rnn for connection
318 | if args.rnnenc and args.rnnmodel != "":
319 | rnnmodel_dict = torch.load(args.rnnmodel).rnn.state_dict()
320 | for i, layer in enumerate(model.rnn_list):
321 | model_dict = layer.state_dict()
322 | update_dict = {k: v for k, v in rnnmodel_dict.items() if k in model_dict}
323 | model_dict.update(update_dict)
324 | model.rnn_list[i].load_state_dict(model_dict)
325 |
326 | if args.fp16:
327 | model = model.half()
328 |
329 | if args.multi_gpu:
330 | model = model.to(device)
331 | if args.gpu0_bsz >= 0:
332 | para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
333 | model, dim=1).to(device)
334 | else:
335 | para_model = nn.DataParallel(model, dim=1).to(device)
336 | else:
337 | para_model = model.to(device)
338 |
339 | #### optimizer
340 | if args.optim.lower() == 'sgd':
341 | if args.sample_softmax > 0:
342 | dense_params, sparse_params = [], []
343 | for param in model.parameters():
344 | if param.size() == model.word_emb.weight.size():
345 | sparse_params.append(param)
346 | else:
347 | dense_params.append(param)
348 | optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
349 | optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
350 | else:
351 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
352 | momentum=args.mom)
353 | elif args.optim.lower() == 'adam':
354 | if args.sample_softmax > 0:
355 | dense_params, sparse_params = [], []
356 | for param in model.parameters():
357 | if param.size() == model.word_emb.weight.size():
358 | sparse_params.append(param)
359 | else:
360 | dense_params.append(param)
361 | optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
362 | optimizer = optim.Adam(dense_params, lr=args.lr)
363 | else:
364 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
365 | elif args.optim.lower() == 'adagrad':
366 | optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
367 |
368 | #### scheduler
369 | if args.scheduler == 'cosine':
370 |     # here we do not set eta_min to lr_min, to stay backward compatible:
371 |     # in previous versions eta_min defaulted to 0
372 |     # rather than to the default value of lr_min (1e-6)
373 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
374 | args.max_step, eta_min=args.eta_min) # should use eta_min arg
375 | if args.sample_softmax > 0:
376 | scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
377 | args.max_step, eta_min=args.eta_min) # should use eta_min arg
378 | elif args.scheduler == 'inv_sqrt':
379 | # originally used for Transformer (in Attention is all you need)
380 | def lr_lambda(step):
381 | # return a multiplier instead of a learning rate
382 | if step == 0 and args.warmup_step == 0:
383 | return 1.
384 | else:
385 | return 1. / (step ** 0.5) if step > args.warmup_step \
386 | else step / (args.warmup_step ** 1.5)
387 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
388 | elif args.scheduler == 'dev_perf':
389 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
390 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
391 | if args.sample_softmax > 0:
392 | scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
393 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
394 | elif args.scheduler == 'constant':
395 | pass
396 |
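# Worked example for the 'inv_sqrt' schedule above (hypothetical warmup_step=1000):
# the multiplier grows as step / warmup_step**1.5 during warmup and then decays as
# step**-0.5, so it peaks at warmup_step**-0.5 (~0.0316 here), never at 1.0.
#
#     step =  500  ->  500 / 1000**1.5  ~= 0.0158
#     step = 1000  -> 1000 / 1000**1.5  ~= 0.0316
#     step = 4000  -> 4000**-0.5        ~= 0.0158
#
# args.lr is therefore an upper bound scaled by this multiplier, as in the original
# "Attention is All You Need" schedule referenced above.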
397 | if args.cuda and args.fp16:
398 | # If args.dynamic_loss_scale is False, static_loss_scale will be used.
399 | # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
400 | optimizer = FP16_Optimizer(optimizer,
401 | static_loss_scale = args.static_loss_scale,
402 | dynamic_loss_scale = args.dynamic_loss_scale,
403 | dynamic_loss_args = {'init_scale': 2 ** 16})
404 |
405 | if args.restart:
406 | if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
407 | with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
408 | opt_state_dict = torch.load(f)
409 | optimizer.load_state_dict(opt_state_dict)
410 | else:
411 | print('Optimizer was not saved. Start from scratch.')
412 |
413 | logging('=' * 100)
414 | for k, v in args.__dict__.items():
415 | logging(' - {} : {}'.format(k, v))
416 | logging('=' * 100)
417 | logging('#params = {}'.format(args.n_all_param))
418 | logging('#non emb params = {}'.format(args.n_nonemb_param))
419 |
420 | ###############################################################################
421 | # Training code
422 | ###############################################################################
423 |
424 | def evaluate(eval_iter):
425 | # Turn on evaluation mode which disables dropout.
426 | model.eval()
427 |
428 | # If the model does not use memory at all, make the ext_len longer.
429 | # Otherwise, make the mem_len longer and keep the ext_len the same.
430 | if args.mem_len == 0:
431 | model.reset_length(args.eval_tgt_len,
432 | args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len)
433 | else:
434 | model.reset_length(args.eval_tgt_len,
435 | args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len)
436 |
437 | # Evaluation
438 | total_len, total_loss = 0, 0.
439 | # add RNN connection
440 | rnn_hidden = None
441 | if args.rnnenc:
442 | rnn_hidden = para_model.init_hidden(eval_batch_size)
443 |
444 | with torch.no_grad():
445 | mems = tuple()
446 | for i, (data, target, seq_len, future_seqlen) in enumerate(eval_iter):
447 | if args.max_eval_steps > 0 and i >= args.max_eval_steps:
448 | break
449 | ret = model(data, target, *mems, rnn_hidden=rnn_hidden, future_seqlen=future_seqlen)
450 | loss, mems, penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1]
451 | loss = loss.mean()
452 | total_loss += seq_len * loss.float().item()
453 | total_len += seq_len
454 |
455 | # Switch back to the training mode
456 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
457 | model.train()
458 |
459 | return total_loss / total_len
460 |
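# Note on evaluate(): reset_length() keeps the total attended context
# (tgt_len + ext_len + mem_len) constant while switching to eval_tgt_len.
# Hypothetical example with tgt_len=32, eval_tgt_len=16, ext_len=0:
#     mem_len=0  -> evaluate with tgt=16, ext=16, mem=0
#     mem_len=32 -> evaluate with tgt=16, ext=0,  mem=48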
461 |
462 | def train():
463 | # Turn on training mode which enables dropout.
464 | global train_step, train_loss, best_val_loss, eval_start_time, log_start_time, train_p
465 | model.train()
466 | if args.batch_chunk > 1:
467 | mems = [tuple() for _ in range(args.batch_chunk)]
468 | else:
469 | mems = tuple()
470 | train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
471 | # add RNN connection
472 | rnn_hidden = None
473 | if args.rnnenc:
474 | rnn_hidden = model.init_hidden(args.batch_size)
475 | prev_data = None
476 | for batch, (data, target, seq_len, future_seqlen) in enumerate(train_iter):
477 | model.zero_grad()
478 | # use RNN state carry on
479 | if rnn_hidden is not None:
480 | rnn_hidden = [repackage_hidden(hid) for hid in rnn_hidden]
481 | if args.batch_chunk > 1:
482 | data_chunks = torch.chunk(data, args.batch_chunk, 1)
483 | target_chunks = torch.chunk(target, args.batch_chunk, 1)
484 | for i in range(args.batch_chunk):
485 | data_i = data_chunks[i].contiguous()
486 | target_i = target_chunks[i].contiguous()
487 |                 ret = para_model(data_i, target_i, *mems[i], rnn_hidden=rnn_hidden)
488 | loss, mems[i], penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1]
489 | loss = loss.float().mean().type_as(loss) / args.batch_chunk
490 | if args.fp16:
491 | optimizer.backward(loss)
492 | else:
493 | loss.backward()
494 | train_loss += loss.float().item()
495 | else:
496 | ret = para_model(data, target, *mems, rnn_hidden=rnn_hidden, future_seqlen=future_seqlen)
497 | prev_data = data
498 | loss, mems, penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1]
499 | loss = loss.float().mean().type_as(loss)
500 | # Add penalty
501 | p_loss = loss + penalty * args.p_scale
502 | if args.fp16:
503 | optimizer.backward(p_loss)
504 | else:
505 | p_loss.backward()
506 | train_loss += loss.float().item()
507 | train_p += penalty.float().item()
508 |
509 | if args.fp16:
510 | optimizer.clip_master_grads(args.clip)
511 | else:
512 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
513 |
514 | optimizer.step()
515 | if args.sample_softmax > 0:
516 | optimizer_sparse.step()
517 |
518 | # step-wise learning rate annealing
519 | train_step += 1
520 | if args.scheduler in ['cosine', 'constant', 'dev_perf']:
521 | # linear warmup stage
522 | if train_step < args.warmup_step:
523 | curr_lr = args.lr * train_step / args.warmup_step
524 | optimizer.param_groups[0]['lr'] = curr_lr
525 | if args.sample_softmax > 0:
526 | optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
527 | else:
528 | if args.scheduler == 'cosine':
529 | scheduler.step(train_step)
530 | if args.sample_softmax > 0:
531 | scheduler_sparse.step(train_step)
532 | elif args.scheduler == 'inv_sqrt':
533 | scheduler.step(train_step)
534 |
535 | if train_step % args.log_interval == 0:
536 | cur_loss = train_loss / args.log_interval
537 | cur_p = train_p / args.log_interval
538 | elapsed = time.time() - log_start_time
539 | log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
540 | '| ms/batch {:5.2f} | loss {:5.2f}'.format(
541 | epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
542 | elapsed * 1000 / args.log_interval, cur_loss)
543 | if args.dataset in ['enwik8', 'text8']:
544 | log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
545 | else:
546 | log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
547 | logging(log_str)
548 | train_loss = 0
549 | train_p = 0
550 | log_start_time = time.time()
551 |
552 | if train_step % args.eval_interval == 0:
553 | val_loss = evaluate(va_iter)
554 | logging('-' * 100)
555 | log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
556 | '| valid loss {:5.2f}'.format(
557 | train_step // args.eval_interval, train_step,
558 | (time.time() - eval_start_time), val_loss)
559 | if args.dataset in ['enwik8', 'text8']:
560 | log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
561 | else:
562 | log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
563 | logging(log_str)
564 | logging('-' * 100)
565 | # Save the model if the validation loss is the best we've seen so far.
566 | if not best_val_loss or val_loss < best_val_loss:
567 | if not args.debug:
568 | with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
569 | torch.save(model, f)
570 | with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
571 | torch.save(optimizer.state_dict(), f)
572 | best_val_loss = val_loss
573 |
574 | # dev-performance based learning rate annealing
575 | if args.scheduler == 'dev_perf':
576 | scheduler.step(val_loss)
577 | if args.sample_softmax > 0:
578 | scheduler_sparse.step(val_loss)
579 |
580 | eval_start_time = time.time()
581 |
582 | if train_step == args.max_step:
583 | break
584 |
585 | # Loop over epochs.
586 | train_step = 0
587 | train_loss = 0
588 | train_p = 0
589 | best_val_loss = None
590 |
591 | log_start_time = time.time()
592 | eval_start_time = time.time()
593 |
594 | # At any point you can hit Ctrl + C to break out of training early.
595 | if not args.evalmode:
596 | try:
597 | for epoch in itertools.count(start=1):
598 | train()
599 | if train_step == args.max_step:
600 | logging('-' * 100)
601 | logging('End of training')
602 | break
603 | except KeyboardInterrupt:
604 | logging('-' * 100)
605 | logging('Exiting from training early')
606 |
607 | # Load the best saved model.
608 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
609 | model = torch.load(f)
610 | para_model = model.to(device)
611 |
612 | # Run on test data.
613 | test_loss = evaluate(te_iter)
614 | logging('=' * 100)
615 | if args.dataset in ['enwik8', 'text8']:
616 | logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
617 | test_loss, test_loss / math.log(2)))
618 | else:
619 | logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
620 | test_loss, math.exp(test_loss)))
621 | logging('=' * 100)
622 |
--------------------------------------------------------------------------------
/rescore.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import argparse
3 | import sys, os
4 | import torch
5 | import math
6 | import time
7 | from operator import itemgetter
8 | import numpy as np
9 | import torch.nn.functional as F
10 |
11 | # import data
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch Level-2 RNN/LSTM Language Model')
14 | parser.add_argument('--data', type=str, default='./data/AMI',
15 | help='location of the data corpus')
16 | parser.add_argument('--model', type=str, default='model.pt',
17 | help='location of the 1st level model')
18 | parser.add_argument('--cuda', action='store_true',
19 | help='use CUDA')
20 | parser.add_argument('--lookback', type=int, default=0,
21 | help='Number of backward utterance embeddings to be incorporated')
22 | parser.add_argument('--uttlookforward', type=int, default=1,
23 | help='Number of forward utterance embeddings to be incorporated')
24 | parser.add_argument('--excludeself', type=int, default=1,
25 |                     help='whether to exclude the current utterance embedding')
26 | parser.add_argument('--saveprefix', type=str, default='tensors/AMI',
27 | help='Specify which data utterance embeddings saved')
28 | parser.add_argument('--nbest', type=str, default='dev.nbest.info.txt',
29 | help='Specify which nbest file to be used')
30 | parser.add_argument('--lmscale', type=float, default=6,
31 | help='how much importance to attach to rnn score')
32 | parser.add_argument('--lm', type=str, default='original',
33 | help='Specify which language model to be used: rnn, ngram or original')
34 | parser.add_argument('--ngramlist', type=str, default='',
35 | help='Specify which ngram stream file to be used')
36 | parser.add_argument('--saveemb', action='store_true',
37 | help='save utterance embeddings')
38 | parser.add_argument('--save1best', action='store_true',
39 | help='save 1best list')
40 | parser.add_argument('--context', type=str, default='0',
41 | help='Specify which utterance embeddings to be used')
42 | parser.add_argument('--use_context', action='store_true',
43 | help='use future context')
44 | parser.add_argument('--contextfile', type=str, default='rescore/time_sorted_dev.nbestlist.context',
45 | help='1best file for context')
46 | parser.add_argument('--logfile', type=str, default='LOGs/log.txt',
47 | help='the logfile for this script')
48 | parser.add_argument('--interp', action='store_true',
49 | help='Linear interpolation of LMs')
50 | parser.add_argument('--factor', type=float, default=0.8,
51 | help='ngram interpolation weight factor')
52 | parser.add_argument('--gscale', type=float, default=12.0,
53 | help='ngram grammar scaling factor')
54 | parser.add_argument('--maxlen', type=int, default=0,
55 | help='how many future words to look at')
56 | parser.add_argument('--use_true', action='store_true',
57 | help='Use true dev file for study')
58 | parser.add_argument('--true_file', type=str, default='data/dev.txt',
59 | help='Specify which true context file to be used')
60 | parser.add_argument('--futurescale', type=float, default=1.0,
61 | help='how much importance to attach to future word scores')
62 | parser.add_argument('--map', type=str, default='rescore/eval.map',
63 | help='mapping file for utterance names')
64 | parser.add_argument('--subbatchsize', type=int, default=20,
65 | help='Sub batch size for batched forwarding')
66 | parser.add_argument('--mem_len', type=int, default=0,
67 |                     help='Transformer-XL memory length used during rescoring')
68 | parser.add_argument('--ppl', action='store_true',
69 | help='Calculate and report ppl')
70 | parser.add_argument('--pplword', action='store_true',
71 | help='Calculate and report average ppl for each word')
72 | parser.add_argument('--extra-model', type=str, default='',
73 | help='Extra LM to be interpolated')
74 | parser.add_argument('--extra-modeltype', type=str, default='RNN',
75 | help='The type of the extra LM to be interpolated')
76 |
77 | args = parser.parse_args()
78 |
79 | def logging(s, print_=True, log_=True):
80 | if print_:
81 | print(s)
82 | if log_:
83 | with open(args.logfile, 'a+') as f_log:
84 | f_log.write(s + '\n')
85 |
86 | # Read in dictionary
87 | logging("Reading dictionary...")
88 | dictionary = {}
89 | with open(os.path.join(args.data, 'dictionary.txt')) as vocabin:
90 | lines = vocabin.readlines()
91 | for line in lines:
92 | ind, word = line.strip().split(' ')
93 | if word not in dictionary:
94 | dictionary[word] = ind
95 | else:
96 | logging("Error! Repeated words in the dictionary!")
97 |
98 | ntokens = len(dictionary)
99 | eosidx = int(dictionary['<eos>'])
100 |
101 | if args.pplword:
102 | wordppl = {}
103 | for word, ind in dictionary.items():
104 | wordppl[word] = [0.0, 0]
105 |
106 | device = torch.device("cuda" if args.cuda else "cpu")
107 | cpu = torch.device("cpu")
108 |
109 | # Read in trained 1st level model
110 | logging("Reading model...")
111 | with open(args.model, 'rb') as f:
112 | model = torch.load(f)
113 |     # after loading, the rnn params are not a continuous chunk of memory
114 | # this makes them a continuous chunk, and will speed up forward pass
115 | if args.cuda:
116 | model.cuda()
117 |
118 | logging("Reading extra model...")
119 | extramodel = None
120 | if args.extra_model != '':
121 | with open(args.extra_model, 'rb') as f:
122 | extramodel = torch.load(f)
123 |         # after loading, the rnn params are not a continuous chunk of memory
124 | # this makes them a continuous chunk, and will speed up forward pass
125 | if args.cuda:
126 | extramodel.cuda()
127 | criterion = torch.nn.CrossEntropyLoss(reduction='none')
128 |
129 | def repackage_hidden(h):
130 | """Wraps hidden states in new Tensors, to detach them from their history."""
131 | if isinstance(h, torch.Tensor):
132 | return h.detach()
133 | else:
134 | return tuple(repackage_hidden(v) for v in h)
135 |
136 | def forward_extra(extra_model, inputs, targets, hidden):
137 | prob_list = []
138 | hidden_list = []
139 | for i, input_data in enumerate(inputs):
140 | target = targets[i]
141 | output, new_hidden = extra_model(input_data.view(-1, 1).to(device), hidden)
142 | probs = criterion(output.squeeze(1), target.to(device))
143 | probs = torch.exp(-probs)
144 | prob_list.append(probs)
145 | hidden_list.append(new_hidden)
146 | return prob_list, hidden_list
147 |
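# Note: with reduction='none', criterion returns the per-token cross-entropy, i.e. the
# negative log-probability of each target word, so torch.exp(-probs) above recovers the
# per-token probability under the extra (RNN) LM. These probabilities are later linearly
# interpolated with the transformer probabilities as
#     P = factor * P_transformer + (1 - factor) * P_rnn,    score = -log(P)
# (see the interpolation block inside the rescoring functions below).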
148 | # Batched forward lookback RNN
149 | def forward_each_utt_batched_lookback_rnn(model,
150 | lines,
151 | utt_name,
152 | prev_utts,
153 | mems,
154 | ppl=False,
155 | hidden=None,
156 | extra_model=None,
157 | extra_hidden=None):
158 | # Process each line
159 | inputs = []
160 | targets = []
161 | ac_scores = []
162 | lm_scores = []
163 | maxlen = 0
164 | # new_mems = mems
165 | utterances = []
166 | utterances_ind = []
167 | extra_inputs = []
168 | extra_targets = []
169 | target_index_list = []
170 |
171 | for line in lines:
172 | linevec = line.strip().split()
173 | ac_score = float(linevec[0])
174 | utterance = linevec[4:-1]
175 | currentline = []
176 | for i, word in enumerate(utterance):
177 | if word in dictionary:
178 | currentline.append(int(dictionary[word]))
179 | else:
180 | currentline.append(int(dictionary['']))
181 | utterances.append(utterance)
182 | utterances_ind.append(currentline)
183 | ac_scores.append(ac_score)
184 | if len(currentline) > maxlen:
185 | maxlen = len(currentline)
186 | mask = []
187 | ac_score_tensor = torch.tensor(ac_scores).to(device)
188 |     # Pad inputs and targets, prev_append in [1, len(prev_utts)]
189 | prev_append = max(min(args.lookback - maxlen, len(prev_utts)), 1)
190 | for i, symbols in enumerate(utterances_ind):
191 | full_sequence = prev_utts[-prev_append:] + symbols + [eosidx] * (maxlen - len(symbols) + 1)
192 | inputs.append(full_sequence[:-1])
193 | targets.append(full_sequence[1:])
194 | # get interpolated model inputs and targets
195 | extra_inputs.append(torch.LongTensor([eosidx] + symbols))
196 | extra_targets.append(torch.LongTensor(symbols+[eosidx]))
197 | mask.append([0.0] * (prev_append - 1) + [1.0] * (len(symbols) + 1) + [0.0] * (maxlen - len(symbols)))
198 | # arrange inputs and targets into tensors
199 | input_tensor = torch.LongTensor(inputs).to(device).t().contiguous()
200 | target_tensor = torch.LongTensor(targets).to(device).t().contiguous()
201 | mask_tensor = torch.tensor(mask).to(device).t().contiguous()
202 | bsize = input_tensor.size(1)
203 | seq_len = input_tensor.size(0)
204 |
205 | # forward prop interpolate model
206 | if args.extra_model != '' and args.extra_modeltype == 'RNN':
207 | interp_prob_list, extra_hidden_list = forward_extra(extra_model, extra_inputs, extra_targets, extra_hidden)
208 |
209 | # Forward prop transformer
210 | logProblist = []
211 | mem_list = []
212 | ppl_list = []
213 | hidden_list = []
214 | # initialise RNN hidden state
215 | if hidden is None and getattr(model, "rnnenc", False):
216 | hidden = model.init_hidden(1)
217 | pos_hidden = [(hid[0][-1:], hid[1][-1:]) for hid in hidden]
218 | elif prev_append == 1 and getattr(model, "rnnenc", False):
219 | pos_hidden = [(hid[0][-1:], hid[1][-1:]) for hid in hidden]
220 | elif getattr(model, "rnnenc", False):
221 | pos_hidden = [(hid[0][-prev_append:-prev_append+1], hid[1][-prev_append:-prev_append+1]) for hid in hidden]
222 | # import pdb; pdb.set_trace()
223 | # transformer XL
224 | tiled_mems = tuple()
225 | if len(mems) > 0 and prev_append < len(prev_utts):
226 | # determine how much memory to keep: prev_append + tiled_mems[0].size(0) = mems[0].size(0)
227 | tiled_mems = [mem[-prev_append-args.mem_len+1:-prev_append+1].repeat(1, bsize, 1) for mem in mems]
228 | # Start forwarding
229 | for i in range(0, bsize, args.subbatchsize):
230 | # mems for transformer XL
231 | if len(tiled_mems) > 0:
232 | this_mem = [mem[:, i:i+args.subbatchsize, :].contiguous() for mem in tiled_mems]
233 | else:
234 | this_mem = tuple()
235 |
236 | bsz = min(args.subbatchsize, bsize - i)
237 | # expand rnn hidden state
238 | rnn_hidden = None
239 | if hidden is not None:
240 | rnn_hidden = [(hid[0].repeat(1, bsz, 1), hid[1].repeat(1, bsz, 1)) for hid in pos_hidden]
241 |
242 | ret = model(input_tensor[:, i:i+args.subbatchsize].contiguous(),
243 | target_tensor[:, i:i+args.subbatchsize].contiguous(),
244 | *this_mem, rnn_hidden=rnn_hidden, stepwise=True)
245 | loss, this_mem, penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1]
246 |
247 | if args.mem_len > 0 and len(this_mem) > 0:
248 | mem_list.append(torch.stack(this_mem))
249 | loss = loss * mask_tensor[:, i:i+args.subbatchsize]
250 | logProblist.append(loss)
251 | hidden_list.append(rnn_hidden)
252 | if args.pplword:
253 | ppl_list.append(loss)
254 | # outputlist.append(output[:,-1,:])
255 | lmscores = torch.cat(logProblist, 1)
256 | if args.extra_model == '':
257 | lmscores = torch.sum(lmscores, dim=0)
258 | else:
259 | interpolated_score = []
260 | for i, probs in enumerate(interp_prob_list):
261 |             transformer_score = lmscores[:,i].tolist()
262 |             transformer_score = torch.tensor([np.exp(-score) for score in transformer_score if score > 0]).to(device)
263 |             assert len(probs) == len(transformer_score)
264 |             lmscore = -torch.log(args.factor * transformer_score + (1 - args.factor) * probs)
265 | interpolated_score.append(torch.sum(lmscore))
266 | lmscores = torch.stack(interpolated_score)
267 | # lmscores = torch.sum(logProb.view(seq_len, bsize)*mask_tensor, 0)
268 | total_scores = - lmscores * args.lmscale + ac_score_tensor
269 | # Get output in some format
270 | outputlines = []
271 | for i, utt in enumerate(utterances):
272 | out = ' '.join([utt_name+'-'+str(i+1), '{:5.2f}'.format(lmscores[i])])
273 | outputlines.append(out+'\n')
274 | max_ind = torch.argmax(total_scores)
275 | # RNN hidden state selection
276 | if len(hidden_list) > 0:
277 | all_hidden = []
278 | for hid in zip(*hidden_list):
279 | hid_l = list(zip(*hid))
280 | all_hidden.append((torch.cat(hid_l[0], dim=1), torch.cat(hid_l[1], dim=1)))
281 | best_hid = [(hid[0][:, max_ind:max_ind+1, :], hid[1][:, max_ind:max_ind+1, :]) for hid in all_hidden]
282 | else:
283 | best_hid = None
284 |
285 | best_utt = utterances[max_ind]
286 | best_utt_len = len(utterances[max_ind])
287 | prev_utts += (utterances_ind[max_ind] + [eosidx])
288 | hidden = [(torch.cat([hidden[i][0][-args.lookback-2:], hid[0][:best_utt_len+1]], dim=0),
289 | torch.cat([hidden[i][1][-args.lookback-2:], hid[1][:best_utt_len+1]], dim=0)) for i, hid in enumerate(best_hid)]
290 | if len(mem_list) > 0:
291 | mem_list = torch.cat(mem_list, dim=2)
292 | start_pos = max(mem_list.size(1) - maxlen - 1, 0)
293 | mem_list = mem_list[:, start_pos:start_pos+len(best_utt)+1, max_ind:max_ind+1, :]
294 | if len(mems) > 0:
295 | mem_list = [torch.cat([mems[i], mem_list[i]])[-(args.mem_len+args.lookback):] for i in range(mem_list.size(0))]
296 | else:
297 | mem_list = [mem_list[i] for i in range(mem_list.size(0))]
298 | # extra hidden states for interpolation
299 | extrahidden = extra_hidden_list[max_ind] if args.extra_model != '' else None
300 | # calculate perplexity
301 | best_ppl = lmscores[max_ind] if ppl else None
302 | # calculate per word perplexity
303 | if args.pplword:
304 | ppl_list = torch.cat(ppl_list, dim=1)[:, max_ind]
305 |         for i, word in enumerate(best_utt+['<eos>']):
306 | if word in wordppl:
307 | wordppl[word][0] += ppl_list[i+prev_append-1]
308 | wordppl[word][1] += 1
309 | else:
310 | wordppl['OOV'][0] += ppl_list[i+prev_append-1]
311 | wordppl['OOV'][1] += 1
312 | return best_utt, outputlines, prev_utts[-args.lookback:], mem_list, best_ppl, hidden, extrahidden
313 |
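# Rescoring recap (applies to all forward_each_utt_* functions): lmscores holds the
# summed negative log-probability of each hypothesis, so the combined score is
#     total = acoustic_score - lmscale * lm_nll
# and the hypothesis with the largest total is kept as the 1-best. Hypothetical
# example with lmscale=10:
#     hyp A: ac=-2500, lm_nll=60 -> total=-3100
#     hyp B: ac=-2520, lm_nll=55 -> total=-3070   <- selected
# Only the memory, RNN state and word context of the selected hypothesis are carried
# over to the next utterance.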
314 | # Batched forward lookback
315 | def forward_each_utt_batched_lookback(model,
316 | lines,
317 | utt_name,
318 | prev_utts,
319 | mems,
320 | ppl=False,
321 | hidden=None,
322 | extra_model=None,
323 | extra_hidden=None):
324 | # Process each line
325 | inputs = []
326 | targets = []
327 | ac_scores = []
328 | lm_scores = []
329 | maxlen = 0
330 | new_mems = mems
331 | utterances = []
332 | utterances_ind = []
333 | target_index_list = []
334 | extra_inputs = []
335 | extra_targets = []
336 |
337 | for line in lines:
338 | linevec = line.strip().split()
339 | ac_score = float(linevec[0])
340 | utterance = linevec[4:-1]
341 | currentline = []
342 | for i, word in enumerate(utterance):
343 | if word in dictionary:
344 | currentline.append(int(dictionary[word]))
345 | else:
346 | currentline.append(int(dictionary['']))
347 | utterances.append(utterance)
348 | utterances_ind.append(currentline)
349 | ac_scores.append(ac_score)
350 | if len(currentline) > maxlen:
351 | maxlen = len(currentline)
352 | mask = []
353 | ac_score_tensor = torch.tensor(ac_scores).to(device)
354 |     # Pad inputs and targets, prev_append in [1, len(prev_utts)]
355 | prev_append = max(min(args.lookback - maxlen, len(prev_utts)), 1)
356 | for i, symbols in enumerate(utterances_ind):
357 | full_sequence = prev_utts[-prev_append:] + symbols + [eosidx] * (maxlen - len(symbols) + 1)
358 | inputs.append(full_sequence[:-1])
359 | targets.append(full_sequence[1:])
360 | # get interpolated model inputs and targets
361 | extra_inputs.append(torch.LongTensor([eosidx] + symbols))
362 | extra_targets.append(torch.LongTensor(symbols+[eosidx]))
363 |
364 | mask.append([0.0] * (prev_append - 1) + [1.0] * (len(symbols) + 1) + [0.0] * (maxlen - len(symbols)))
365 | # arrange inputs and targets into tensors
366 | input_tensor = torch.LongTensor(inputs).to(device).t().contiguous()
367 | target_tensor = torch.LongTensor(targets).to(device).t().contiguous()
368 | mask_tensor = torch.tensor(mask).to(device).t().contiguous()
369 | bsize = input_tensor.size(1)
370 | seq_len = input_tensor.size(0)
371 |
372 | # forward prop interpolate model
373 | if args.extra_model != '' and args.extra_modeltype == 'RNN':
374 | interp_prob_list, extra_hidden_list = forward_extra(extra_model, extra_inputs, extra_targets, extra_hidden)
375 |
376 | # Forward prop transformer
377 | logProblist = []
378 | mem_list = []
379 | ppl_list = []
380 |     # initialise RNN hidden state
381 | if hidden is None and getattr(model, "rnnenc", False):
382 | hidden = model.init_hidden(1)
383 | # transformer XL
384 | tiled_mems = tuple()
385 | if len(mems) > 0 and prev_append < len(prev_utts):
386 | # determine how much memory to keep: prev_append + tiled_mems[0].size(0) = mems[0].size(0)
387 | tiled_mems = [mem[-prev_append-args.mem_len+1:-prev_append+1].repeat(1, bsize, 1) for mem in mems]
388 | for i in range(0, bsize, args.subbatchsize):
389 | # mems for transformer XL
390 | if len(tiled_mems) > 0:
391 | this_mem = [mem[:, i:i+args.subbatchsize, :].contiguous() for mem in tiled_mems]
392 | else:
393 | this_mem = tuple()
394 |
395 | bsz = min(args.subbatchsize, bsize - i)
396 | # expand rnn hidden state
397 | rnn_hidden = None
398 | if hidden is not None:
399 | rnn_hidden = [(hid[0].repeat(1, bsz, 1), hid[1].repeat(1, bsz, 1)) for hid in hidden]
400 |
401 | ret = model(input_tensor[:, i:i+args.subbatchsize].contiguous(),
402 | target_tensor[:, i:i+args.subbatchsize].contiguous(),
403 | *this_mem, rnn_hidden=rnn_hidden)
404 | loss, this_mem, penalty, hidden = ret[0], ret[1:-2], ret[-2], ret[-1]
405 | if args.mem_len > 0 and len(this_mem) > 0:
406 | mem_list.append(torch.stack(this_mem))
407 | loss = loss * mask_tensor[:, i:i+args.subbatchsize]
408 | logProblist.append(loss)
409 | if args.pplword:
410 | ppl_list.append(loss)
411 | # outputlist.append(output[:,-1,:])
412 | lmscores = torch.cat(logProblist, 1)
413 | if args.extra_model == '':
414 | lmscores = torch.sum(lmscores, dim=0)
415 | else:
416 | interpolated_score = []
417 | for i, probs in enumerate(interp_prob_list):
418 |             transformer_score = lmscores[:,i].tolist()
419 |             transformer_score = torch.tensor([np.exp(-score) for score in transformer_score if score > 0]).to(device)
420 |             assert len(probs) == len(transformer_score)
421 |             lmscore = -torch.log(args.factor * transformer_score + (1 - args.factor) * probs)
422 | interpolated_score.append(torch.sum(lmscore))
423 | lmscores = torch.stack(interpolated_score)
424 | # lmscores = torch.sum(logProb.view(seq_len, bsize)*mask_tensor, 0)
425 | total_scores = - lmscores * args.lmscale + ac_score_tensor
426 | # Get output in some format
427 | outputlines = []
428 | for i, utt in enumerate(utterances):
429 | out = ' '.join([utt_name+'-'+str(i+1), '{:5.2f}'.format(lmscores[i])])
430 | outputlines.append(out+'\n')
431 | max_ind = torch.argmax(total_scores)
432 | best_utt = utterances[max_ind]
433 | prev_utts += (utterances_ind[max_ind] + [eosidx])
434 | if len(mem_list) > 0:
435 | new_mems = torch.cat(mem_list, dim=2)
436 | start_pos = max(new_mems.size(1) - maxlen - 1, 0)
437 | new_mems = new_mems[:, start_pos:start_pos+len(best_utt)+1, max_ind:max_ind+1, :]
438 | if len(mems) > 0:
439 | new_mems = [torch.cat([mems[i], new_mems[i]])[-(args.mem_len+args.lookback):] for i in range(new_mems.size(0))]
440 | else:
441 | new_mems = [new_mems[i] for i in range(new_mems.size(0))]
442 | # extra hidden states for interpolation
443 | extrahidden = extra_hidden_list[max_ind] if args.extra_model != '' else None
444 | # calculate perplexity
445 | best_ppl = lmscores[max_ind] if ppl else None
446 | # calculate per word perplexity
447 | if args.pplword:
448 | ppl_list = torch.cat(ppl_list, dim=1)[:, max_ind]
449 |         for i, word in enumerate(best_utt+['<eos>']):
450 | if word in wordppl:
451 | wordppl[word][0] += ppl_list[i+prev_append-1]
452 | wordppl[word][1] += 1
453 | else:
454 | wordppl['OOV'][0] += ppl_list[i+prev_append-1]
455 | wordppl['OOV'][1] += 1
456 | return best_utt, outputlines, prev_utts[-args.lookback:], new_mems, best_ppl, extrahidden
457 |
458 | # Batched forward
459 | def forward_each_utt_batched(model, lines, utt_name, prev_utts, mems, ppl=False, hidden=None):
460 | # Process each line
461 | inputs = []
462 | targets = []
463 | ac_scores = []
464 | lm_scores = []
465 | maxlen = 0
466 | new_mems = mems
467 | utterances = []
468 | utterances_ind = []
469 |
470 | for line in lines:
471 | linevec = line.strip().split()
472 | ac_score = float(linevec[0])
473 | utterance = linevec[4:-1]
474 | currentline = []
475 | for i, word in enumerate(utterance):
476 | if word in dictionary:
477 | currentline.append(int(dictionary[word]))
478 | else:
479 | currentline.append(int(dictionary['']))
480 | currentline = [eosidx] + currentline
481 | currenttarget = currentline[1:]
482 | currenttarget.append(eosidx)
483 | inputs.append(currentline)
484 | targets.append(currenttarget)
485 | utterances.append(utterance)
486 | utterances_ind.append(currenttarget)
487 | ac_scores.append(ac_score)
488 | if len(currentline) > maxlen:
489 | maxlen = len(currentline)
490 | mask = []
491 | ac_score_tensor = torch.tensor(ac_scores).to(device)
492 | for i, symbols in enumerate(inputs):
493 | inputs[i] = symbols + [eosidx] * (maxlen - len(symbols))
494 | targets[i] = targets[i] + [eosidx] * (maxlen - len(symbols))
495 | mask.append([1.0] * len(symbols) + [0.0] * (maxlen - len(symbols)))
496 |
497 | input_tensor = torch.LongTensor(inputs).to(device).t().contiguous()
498 | target_tensor = torch.LongTensor(targets).to(device).t().contiguous()
499 | mask_tensor = torch.tensor(mask).to(device).t().contiguous()
500 | bsize = input_tensor.size(1)
501 | seq_len = input_tensor.size(0)
502 |
503 | # Forward prop transformer
504 | logProblist = []
505 | mem_list = []
506 | ppl_list = []
507 | hidden_list = []
508 | if hidden is None and getattr(model, "rnnenc", False):
509 | hidden = model.init_hidden(1)
510 | # transformer XL
511 | prev_mem_len = 0
512 | if len(mems) > 0:
513 | prev_mem_len = mems[0].size(0)
514 | tiled_mems = [mem.repeat(1, bsize, 1) for mem in mems]
515 | for i in range(0, bsize, args.subbatchsize):
516 | # mems for transformer XL
517 | if len(mems) > 0:
518 | this_mem = [mem[:, i:i+args.subbatchsize, :].contiguous() for mem in tiled_mems]
519 | else:
520 | this_mem = mems
521 | bsz = min(args.subbatchsize, bsize - i)
522 | # expand rnn hidden state
523 | rnn_hidden = None
524 | if hidden is not None:
525 | rnn_hidden = [(hid[0].repeat(1, bsz, 1), hid[1].repeat(1, bsz, 1)) for hid in hidden]
526 | # forward pass
527 | ret = model(input_tensor[:, i:i+args.subbatchsize].contiguous(),
528 | target_tensor[:, i:i+args.subbatchsize].contiguous(),
529 | *this_mem, rnn_hidden=rnn_hidden)
530 |         loss, this_mems, penalty, rnn_hidden = ret[0], ret[1:-2], ret[-2], ret[-1]
531 | if args.mem_len > 0:
532 | mem_list.append(torch.stack(this_mems))
533 | if hidden is not None:
534 | hidden_list.append(rnn_hidden)
535 | loss = loss * mask_tensor[:, i:i+args.subbatchsize]
536 | logProblist.append(torch.sum(loss, dim=0))
537 | # outputlist.append(output[:,-1,:])
538 | lmscores = torch.cat(logProblist, 0)
539 | # lmscores = torch.sum(logProb.view(seq_len, bsize)*mask_tensor, 0)
540 | total_scores = - lmscores * args.lmscale + ac_score_tensor
541 | # Get output in some format
542 | outputlines = []
543 | for i, utt in enumerate(utterances):
544 | out = ' '.join([utt_name+'-'+str(i+1), '{:5.2f}'.format(lmscores[i])])
545 | outputlines.append(out+'\n')
546 | max_ind = torch.argmax(total_scores)
547 | # choose best rnn hidden state
548 | if len(hidden_list) > 0:
549 | all_hidden = []
550 | for hid in zip(*hidden_list):
551 | hid_l = list(zip(*hid))
552 | all_hidden.append((torch.cat(hid_l[0], dim=1), torch.cat(hid_l[1], dim=1)))
553 | best_hid = [(hid[0][:, max_ind:max_ind+1, :], hid[1][:, max_ind:max_ind+1, :]) for hid in all_hidden]
554 | else:
555 | best_hid = None
556 |
557 | best_utt = utterances[max_ind]
558 | prev_utts += utterances_ind[max_ind]
559 | if len(mem_list) > 0:
560 | new_mems = torch.cat(mem_list, dim=2)
561 | start_pos = max(new_mems.size(1) - seq_len, 0)
562 | new_mems = new_mems[:, start_pos:start_pos+len(best_utt)+1, max_ind:max_ind+1, :]
563 | if len(mems) > 0:
564 | new_mems = [torch.cat([mems[i], new_mems[i]])[-args.mem_len:] for i in range(new_mems.size(0))]
565 | else:
566 | new_mems = [new_mems[i] for i in range(new_mems.size(0))]
567 | # calculate perplexity
568 | best_ppl = lmscores[max_ind] if ppl else None
569 | return best_utt, outputlines, prev_utts, new_mems, best_ppl, best_hid
570 |
571 | def forward_nbest_utterance(model, nbestfile, extramodel=None):
572 | """ The main body of the rescore function. """
573 | model.eval()
574 | # decide if we calculate the average of the loss
575 | if args.interp:
576 | forwardCrit = torch.nn.CrossEntropyLoss(reduction='none')
577 | else:
578 | forwardCrit = torch.nn.CrossEntropyLoss()
579 |
580 | extrahidden = None
581 | if args.extra_model != '' and extramodel is not None:
582 | extramodel.eval()
583 | extrahidden = extramodel.init_hidden(1)
584 | # initialising variables needed
585 | ngram_cursor = 0
586 | lmscored_lines = []
587 | best_utt_list = []
588 | emb_list = []
589 | utt_idx = 0
590 | prev_utts = [eosidx] # * args.lookback
591 | start = time.time()
592 | best_hid = None
593 | best_ppl = None
594 | mems = tuple()
595 |
596 | total_ppl = torch.zeros(1)
597 | total_len = 0
598 |
599 | # Ngram used for lattice rescoring
600 | with open(nbestfile) as filein:
601 | with torch.no_grad():
602 | for utterancefile in filein:
603 | # Iterating over the nbest list
604 | labname = utterancefile.strip().split('/')[-1]
605 | # Read in ngram LM files for interpolation
606 | if args.interp:
607 | ngram_probfile_name = ngram_listfile.readline()
608 | ngram_probfile = open(ngram_probfile_name.strip())
609 | ngram_prob_lines = ngram_probfile.readlines()
610 | future_context = future_context_dict[utt_idx] if args.use_context else []
611 |
612 | # Start processing each nbestlist
613 | with open(utterancefile.strip()) as uttfile:
614 | uttlines = uttfile.readlines()
615 |
616 | uttscore = []
617 | # Do re-ranking batch by batch
618 | if not args.interp:
619 | if args.lookback > 0 and getattr(model, "rnnenc", False):
620 | bestutt, to_write, prev_utts, mems, best_ppl, best_hid, extrahidden = forward_each_utt_batched_lookback_rnn(
621 | model, uttlines, labname, prev_utts, mems, args.ppl, best_hid,
622 | extra_model=extramodel, extra_hidden=extrahidden)
623 | elif args.lookback > 0:
624 | bestutt, to_write, prev_utts, mems, best_ppl, extrahidden = forward_each_utt_batched_lookback(
625 | model, uttlines, labname, prev_utts, mems, args.ppl,
626 | extra_model=extramodel, extra_hidden=extrahidden)
627 | else:
628 | bestutt, to_write, prev_utts, mems, best_ppl, best_hid = forward_each_utt_batched(
629 | model, uttlines, labname, prev_utts, mems, args.ppl, best_hid)
630 | lmscored_lines += to_write
631 | utt_idx += 1
632 | best_utt_list.append((labname, bestutt))
633 | if args.ppl and best_ppl is not None:
634 | total_ppl += torch.sum(best_ppl)
635 | total_len += len(bestutt) + 1
636 | # Log every completion of n utterances
637 | if utt_idx % 100 == 0:
638 | logging("current ppl is {:5.2f}".format(torch.exp(total_ppl / total_len).item()))
639 |                     logging('rescored {} utterances, time elapsed {:6.2f}'.format(str(utt_idx), time.time()-start))
640 | # Write out renewed lmscore file
641 | with open(nbestfile+'.renew.'+args.lm, 'w') as fout:
642 | fout.writelines(lmscored_lines)
643 | # Save 1-best for later use for the context
644 | if args.save1best:
645 | # Write out for second level forwarding
646 | with open(nbestfile+'.context', 'w') as fout:
647 | for i, eachutt in enumerate(best_utt_list):
648 | linetowrite = ' ' + ' '.join(eachutt[1]) + ' \n'
649 | fout.write(linetowrite)
650 | if args.ppl:
651 | print(total_len)
652 | print(torch.exp(total_ppl / total_len))
653 |
654 | logging('getting utterances')
655 | forward_nbest_utterance(model, args.nbest, extramodel)
656 | if args.pplword:
657 |     with open("word_ppl_{}".format(args.lm), 'w') as fout:
658 |         for word, group in wordppl.items():
659 |             total_ppl, total_count = group
660 |             if total_count > 0:
661 |                 fout.write('{}\t\t{}\t\t{:5.3f}\n'.format(word, total_count, float(total_ppl/total_count)))
662 |
--------------------------------------------------------------------------------
/mem_transformer_rnn.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import functools
4 |
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | sys.path.append('utils')
12 | from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
13 | from log_uniform_sampler import LogUniformSampler, sample_logits
14 |
15 | class PositionalEmbedding(nn.Module):
16 | def __init__(self, demb):
17 | super(PositionalEmbedding, self).__init__()
18 |
19 | self.demb = demb
20 |
21 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
22 | self.register_buffer('inv_freq', inv_freq)
23 |
24 | def forward(self, pos_seq, bsz=None):
25 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
26 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
27 |
28 | if bsz is not None:
29 | return pos_emb[:,None,:].expand(-1, bsz, -1)
30 | else:
31 | return pos_emb[:,None,:]
32 |
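# Usage sketch for PositionalEmbedding (illustrative only; klen is hypothetical):
#
#     pos_emb_layer = PositionalEmbedding(demb=512)
#     pos_seq = torch.arange(klen - 1, -1, -1.0)      # klen-1, ..., 1, 0
#     pos_emb = pos_emb_layer(pos_seq)                # [klen, 1, 512]
#
# Each position p is encoded as [sin(p * f_0), ..., sin(p * f_{d/2-1}),
# cos(p * f_0), ..., cos(p * f_{d/2-1})] with f_k = 1 / 10000**(2k/demb),
# which is how the model below builds its relative position embeddings.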
33 |
34 | class PositionwiseFF(nn.Module):
35 | def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
36 | super(PositionwiseFF, self).__init__()
37 |
38 | self.d_model = d_model
39 | self.d_inner = d_inner
40 | self.dropout = dropout
41 |
42 | self.CoreNet = nn.Sequential(
43 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
44 | nn.Dropout(dropout),
45 | nn.Linear(d_inner, d_model),
46 | nn.Dropout(dropout),
47 | )
48 |
49 | self.layer_norm = nn.LayerNorm(d_model)
50 |
51 | self.pre_lnorm = pre_lnorm
52 |
53 | def forward(self, inp):
54 | if self.pre_lnorm:
55 | ##### layer normalization + positionwise feed-forward
56 | core_out = self.CoreNet(self.layer_norm(inp))
57 |
58 | ##### residual connection
59 | output = core_out + inp
60 | else:
61 | ##### positionwise feed-forward
62 | core_out = self.CoreNet(inp)
63 |
64 | ##### residual connection + layer normalization
65 | output = self.layer_norm(inp + core_out)
66 |
67 | return output
68 |
69 | class MultiHeadAttn(nn.Module):
70 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
71 | pre_lnorm=False, penalty=False):
72 | super(MultiHeadAttn, self).__init__()
73 |
74 | self.n_head = n_head
75 | self.d_model = d_model
76 | self.d_head = d_head
77 | self.dropout = dropout
78 |
79 | self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
80 | self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
81 |
82 | self.drop = nn.Dropout(dropout)
83 | self.dropatt = nn.Dropout(dropatt)
84 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
85 |
86 | self.layer_norm = nn.LayerNorm(d_model)
87 |
88 | self.scale = 1 / (d_head ** 0.5)
89 |
90 | self.pre_lnorm = pre_lnorm
91 | self.penalty = penalty
92 |
93 | def calc_penalty(self, attn_prob):
94 | if self.penalty:
95 | ATA = torch.einsum('ijbm,ijbn->ibmn', attn_prob, attn_prob)
96 | seqlen = attn_prob.size(0)
97 | lambda_array = torch.diag(torch.tensor([1.0]*(self.n_head)))
98 | pen = ((ATA - torch.eye(ATA.size(2), device=ATA.device)) ** 2).sum()
99 | else:
100 | pen = attn_prob.new_zeros(1)
101 | return pen
102 |
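    # Note on calc_penalty(): for every query position and batch element, attn_prob
    # gives one probability distribution over keys per head (a [klen x n_head]
    # matrix A). The penalty ||A^T A - I||^2 is small only when the heads attend to
    # (nearly) disjoint, sharp distributions, so it encourages head diversity.
    # seqlen and lambda_array are computed but not used in this version.
    # Tiny illustrative example with 2 heads over 2 keys (columns are heads):
    #     A = [[1, 0], [0, 1]]  ->  A^T A = I                ->  penalty 0
    #     A = [[1, 1], [0, 0]]  ->  A^T A = [[1, 1], [1, 1]]  ->  penalty 2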
103 | def forward(self, h, attn_mask=None, mems=None):
104 | ##### multihead attention
105 | # [hlen x bsz x n_head x d_head]
106 |
107 | if mems is not None:
108 | c = torch.cat([mems, h], 0)
109 | else:
110 | c = h
111 |
112 | if self.pre_lnorm:
113 | ##### layer normalization
114 | c = self.layer_norm(c)
115 |
116 | head_q = self.q_net(h)
117 | head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
118 |
119 | head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
120 | head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
121 | head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
122 |
123 | # [qlen x klen x bsz x n_head]
124 | attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
125 | attn_score.mul_(self.scale)
126 | if attn_mask is not None and attn_mask.any().item():
127 | if attn_mask.dim() == 2:
128 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
129 | elif attn_mask.dim() == 3:
130 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
131 |
132 | # [qlen x klen x bsz x n_head]
133 | attn_prob = F.softmax(attn_score, dim=1)
134 | attn_penalty = self.calc_penalty(attn_prob)
135 | attn_prob = self.dropatt(attn_prob)
136 |
137 | # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
138 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
139 | attn_vec = attn_vec.contiguous().view(
140 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
141 |
142 | ##### linear projection
143 | attn_out = self.o_net(attn_vec)
144 | attn_out = self.drop(attn_out)
145 |
146 | if self.pre_lnorm:
147 | ##### residual connection
148 | output = h + attn_out
149 | else:
150 | ##### residual connection + layer normalization
151 | output = self.layer_norm(h + attn_out)
152 |
153 | return output, attn_penalty
154 |
155 | class RelMultiHeadAttn(nn.Module):
156 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
157 | tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, penalty=False):
158 | super(RelMultiHeadAttn, self).__init__()
159 |
160 | self.n_head = n_head
161 | self.d_model = d_model
162 | self.d_head = d_head
163 | self.dropout = dropout
164 |
165 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
166 |
167 | self.drop = nn.Dropout(dropout)
168 | self.dropatt = nn.Dropout(dropatt)
169 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
170 |
171 | self.layer_norm = nn.LayerNorm(d_model)
172 |
173 | self.scale = 1 / (d_head ** 0.5)
174 |
175 | self.pre_lnorm = pre_lnorm
176 | self.penalty = penalty
177 |
178 | def _parallelogram_mask(self, h, w, left=False):
179 | mask = torch.ones((h, w)).byte()
180 | m = min(h, w)
181 | mask[:m,:m] = torch.triu(mask[:m,:m])
182 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
183 |
184 | if left:
185 | return mask
186 | else:
187 | return mask.flip(0)
188 |
189 | def _shift(self, x, qlen, klen, mask, left=False):
190 | if qlen > 1:
191 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
192 | device=x.device, dtype=x.dtype)
193 | else:
194 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
195 |
196 | if left:
197 | mask = mask.flip(1)
198 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
199 | else:
200 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
201 |
202 | x = x_padded.masked_select(mask[:,:,None,None]) \
203 | .view(qlen, klen, x.size(2), x.size(3))
204 |
205 | return x
206 |
207 | def _calc_penalty(self, attn_prob):
208 | if self.penalty:
209 | ATA = torch.einsum('ijbm,ijbn->ibmn', attn_prob, attn_prob)
210 | seqlen = attn_prob.size(0)
211 | lambda_array = torch.diag(torch.tensor([0.5]*(self.n_head)))
212 | pen = ((ATA - torch.eye(ATA.size(2), device=ATA.device)) ** 2).sum()
213 | else:
214 | pen = attn_prob.new_zeros(1)
215 | return pen
216 |
217 | def _rel_shift(self, x, zero_triu=False):
218 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
219 | device=x.device, dtype=x.dtype)
220 | # import pdb; pdb.set_trace()
221 | x_padded = torch.cat([zero_pad, x], dim=1)
222 |
223 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
224 |
225 | x = x_padded[1:].view_as(x)
226 |
227 | if zero_triu:
228 | ones = torch.ones((x.size(0), x.size(1)))
229 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
230 |
231 | return x
232 |
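    # Worked example of _rel_shift() with qlen = klen = 3 (batch/head dims omitted).
    # Input rows are queries, columns index the relative position embeddings:
    #
    #     [[a0, a1, a2],            [[a2,  0, b0],
    #      [b0, b1, b2],    --->     [b1, b2,  0],
    #      [c0, c1, c2]]             [c0, c1, c2]]
    #
    # Row i keeps the entries whose relative distance is valid for query i (row 2 is
    # unshifted, row 1 is shifted left by one, ...); the remaining slots hold left-over
    # values that the causal attention mask removes afterwards.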
233 | def forward(self, w, r, attn_mask=None, mems=None):
234 | raise NotImplementedError
235 |
236 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
237 | def __init__(self, *args, **kwargs):
238 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
239 |
240 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
241 |
242 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
243 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
244 |
245 | if mems is not None:
246 | cat = torch.cat([mems, w], 0)
247 | if self.pre_lnorm:
248 | w_heads = self.qkv_net(self.layer_norm(cat))
249 | else:
250 | w_heads = self.qkv_net(cat)
251 | r_head_k = self.r_net(r)
252 |
253 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
254 | w_head_q = w_head_q[-qlen:]
255 | else:
256 | if self.pre_lnorm:
257 | w_heads = self.qkv_net(self.layer_norm(w))
258 | else:
259 | w_heads = self.qkv_net(w)
260 | r_head_k = self.r_net(r)
261 |
262 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
263 |
264 | klen = w_head_k.size(0)
265 |
266 | # import pdb; pdb.set_trace()
267 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
268 |         w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
269 |         w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
270 |
271 |         r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head
272 |
273 | #### compute attention score
274 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
275 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
276 |
277 | rr_head_q = w_head_q + r_r_bias
278 | BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
279 | BD = self._rel_shift(BD)
280 |
281 | # [qlen x klen x bsz x n_head]
282 | attn_score = AC + BD
283 | attn_score.mul_(self.scale)
284 |
285 | #### compute attention probability
286 | if attn_mask is not None and attn_mask.any().item():
287 | if attn_mask.dim() == 2:
288 | attn_score = attn_score.float().masked_fill(
289 | attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
290 | elif attn_mask.dim() == 3:
291 | attn_score = attn_score.float().masked_fill(
292 | attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)
293 |
294 | # [qlen x klen x bsz x n_head]
295 | attn_prob = F.softmax(attn_score, dim=1)
296 | attn_prob = self.dropatt(attn_prob)
297 | attn_penalty = self._calc_penalty(attn_prob)
298 |
299 | #### compute attention vector
300 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
301 |
302 | # [qlen x bsz x n_head x d_head]
303 | attn_vec = attn_vec.contiguous().view(
304 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
305 |
306 | ##### linear projection
307 | attn_out = self.o_net(attn_vec)
308 | attn_out = self.drop(attn_out)
309 |
310 | if self.pre_lnorm:
311 | ##### residual connection
312 | output = w + attn_out
313 | else:
314 | ##### residual connection + layer normalization
315 | output = self.layer_norm(w + attn_out)
316 |
317 | return output, attn_penalty
318 |
319 | class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
320 | def __init__(self, *args, **kwargs):
321 | super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
322 |
323 | def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
324 | # r_emb: [klen, n_head, d_head], used for term B
325 | # r_w_bias: [n_head, d_head], used for term C
326 | # r_bias: [klen, n_head], used for term D
327 |
328 | qlen, bsz = w.size(0), w.size(1)
329 |
330 | if mems is not None:
331 | cat = torch.cat([mems, w], 0)
332 | if self.pre_lnorm:
333 | w_heads = self.qkv_net(self.layer_norm(cat))
334 | else:
335 | w_heads = self.qkv_net(cat)
336 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
337 |
338 | w_head_q = w_head_q[-qlen:]
339 | else:
340 | if self.pre_lnorm:
341 | w_heads = self.qkv_net(self.layer_norm(w))
342 | else:
343 | w_heads = self.qkv_net(w)
344 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
345 |
346 | klen = w_head_k.size(0)
347 |
348 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
349 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
350 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
351 |
352 | if klen > r_emb.size(0):
353 | r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
354 | r_emb = torch.cat([r_emb_pad, r_emb], 0)
355 | r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
356 | r_bias = torch.cat([r_bias_pad, r_bias], 0)
357 | else:
358 | r_emb = r_emb[-klen:]
359 | r_bias = r_bias[-klen:]
360 |
361 | #### compute attention score
362 | rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
363 |
364 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
365 | B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head
366 | D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head
367 | BD = self._rel_shift(B_ + D_)
368 |
369 | # [qlen x klen x bsz x n_head]
370 | attn_score = AC + BD
371 | attn_score.mul_(self.scale)
372 |
373 | #### compute attention probability
374 | if attn_mask is not None and attn_mask.any().item():
375 | if attn_mask.dim() == 2:
376 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
377 | elif attn_mask.dim() == 3:
378 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
379 |
380 | # [qlen x klen x bsz x n_head]
381 | attn_prob = F.softmax(attn_score, dim=1)
382 | attn_prob = self.dropatt(attn_prob)
383 |
384 | #### compute attention vector
385 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
386 |
387 | # [qlen x bsz x n_head x d_head]
388 | attn_vec = attn_vec.contiguous().view(
389 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
390 |
391 | ##### linear projection
392 | attn_out = self.o_net(attn_vec)
393 | attn_out = self.drop(attn_out)
394 |
395 | if self.pre_lnorm:
396 | ##### residual connection
397 | output = w + attn_out
398 | else:
399 | ##### residual connection + layer normalization
400 | output = self.layer_norm(w + attn_out)
401 |
402 | return output
403 |
404 | class DecoderLayer(nn.Module):
405 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
406 | super(DecoderLayer, self).__init__()
407 |
408 | self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
409 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
410 | pre_lnorm=kwargs.get('pre_lnorm'))
411 |
412 | def forward(self, dec_inp, dec_attn_mask=None, mems=None):
413 |
414 | output, attn_pen = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
415 | mems=mems)
416 | output = self.pos_ff(output)
417 |
418 | return output, attn_pen
419 |
420 | class RelLearnableDecoderLayer(nn.Module):
421 | def __init__(self, n_head, d_model, d_head, d_inner, dropout,
422 | **kwargs):
423 | super(RelLearnableDecoderLayer, self).__init__()
424 |
425 | self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
426 | **kwargs)
427 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
428 | pre_lnorm=kwargs.get('pre_lnorm'))
429 |
430 | def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
431 |
432 | output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
433 | attn_mask=dec_attn_mask,
434 | mems=mems)
435 | output = self.pos_ff(output)
436 |
437 | return output
438 |
439 | class RelPartialLearnableDecoderLayer(nn.Module):
440 | def __init__(self, n_head, d_model, d_head, d_inner, dropout,
441 | **kwargs):
442 | super(RelPartialLearnableDecoderLayer, self).__init__()
443 |
444 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
445 | d_head, dropout, **kwargs)
446 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
447 | pre_lnorm=kwargs.get('pre_lnorm'))
448 |
449 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
450 |
451 | output, attn_p = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
452 | attn_mask=dec_attn_mask,
453 | mems=mems)
454 | output = self.pos_ff(output)
455 |
456 | return output, attn_p
457 |
458 |
459 | class AdaptiveEmbedding(nn.Module):
460 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
461 | sample_softmax=False):
462 | super(AdaptiveEmbedding, self).__init__()
463 |
464 | self.n_token = n_token
465 | self.d_embed = d_embed
466 |
467 | self.cutoffs = cutoffs + [n_token]
468 | self.div_val = div_val
469 | self.d_proj = d_proj
470 |
471 | self.emb_scale = d_proj ** 0.5
472 |
473 | self.cutoff_ends = [0] + self.cutoffs
474 |
475 | self.emb_layers = nn.ModuleList()
476 | self.emb_projs = nn.ParameterList()
477 | if div_val == 1:
478 | self.emb_layers.append(
479 | nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
480 | )
481 | if d_proj != d_embed:
482 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
483 | else:
484 | for i in range(len(self.cutoffs)):
485 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
486 | d_emb_i = d_embed // (div_val ** i)
487 | self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
488 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
489 |
490 | def forward(self, inp):
491 | if self.div_val == 1:
492 | embed = self.emb_layers[0](inp)
493 | if self.d_proj != self.d_embed:
494 | embed = F.linear(embed, self.emb_projs[0])
495 | else:
496 | param = next(self.parameters())
497 | inp_flat = inp.view(-1)
498 | emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
499 | dtype=param.dtype, device=param.device)
500 | for i in range(len(self.cutoffs)):
501 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
502 |
503 | mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
504 | indices_i = mask_i.nonzero().squeeze()
505 |
506 | if indices_i.numel() == 0:
507 | continue
508 |
509 | inp_i = inp_flat.index_select(0, indices_i) - l_idx
510 | emb_i = self.emb_layers[i](inp_i)
511 | emb_i = F.linear(emb_i, self.emb_projs[i])
512 |
513 | emb_flat.index_copy_(0, indices_i, emb_i)
514 |
515 | embed = emb_flat.view(*inp.size(), self.d_proj)
516 |
517 | embed.mul_(self.emb_scale)
518 |
519 | return embed
520 |
521 |
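    | # Transformer-XL language model with optional LSTM layers interleaved between the
    | # attention blocks. `layer_list` gives the (0-based) transformer layers whose inputs are
    | # first passed through an LSTM of width `rnndim`; `merge_type` chooses how the LSTM and
    | # transformer states are combined ('direct' replacement, sigmoid 'gating', or a linear
    | # 'project'-ion of their concatenation); `future_len` lets the layers in `layer_list`
    | # additionally attend to a short future segment; `attn_layerlist` marks the layers whose
    | # attention penalty is accumulated and returned alongside the loss.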
522 | class MemTransformerLM(nn.Module):
523 | def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
524 | dropout, dropatt, tie_weight=True, d_embed=None,
525 | div_val=1, tie_projs=[False], pre_lnorm=False,
526 | tgt_len=None, ext_len=None, mem_len=None,
527 | cutoffs=[], adapt_inp=False,
528 | same_length=False, attn_type=0, clamp_len=-1,
529 | sample_softmax=-1, rnnenc=False, rnndim=0,
530 | layer_list='', future_len=0, attn_layerlist='', merge_type='direct'):
531 | super(MemTransformerLM, self).__init__()
532 | self.n_token = n_token
533 |
534 | d_embed = d_model if d_embed is None else d_embed
535 | self.d_embed = d_embed
536 | self.d_model = d_model
537 | self.n_head = n_head
538 | self.d_head = d_head
539 |
540 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
541 | div_val=div_val)
542 |
543 | self.drop = nn.Dropout(dropout)
544 |
545 | self.n_layer = n_layer
546 |
547 | self.tgt_len = tgt_len
548 | self.mem_len = mem_len
549 | self.ext_len = ext_len
550 | self.future_len = future_len
551 | self.max_klen = tgt_len + ext_len + mem_len + future_len
552 |
553 | self.attn_type = attn_type
554 |
555 | # RNN hidden state carry on
556 | self.layer_list = [int(i) for i in layer_list.split()]
557 | self.rnnlayer_list = self.layer_list
558 | print("rnn layer list: {}".format(self.rnnlayer_list))
559 | if rnnenc and rnndim != 0:
560 | if merge_type in ['gating', 'project']:
561 | self.rnnproj = nn.Linear(rnndim + d_model, d_model)
562 | self.rnn_list = nn.ModuleList([nn.LSTM(d_model, rnndim, 1) for i in range(len(self.rnnlayer_list))])
563 | # attn penalisation
564 | self.attn_pen_layers = [int(i) for i in attn_layerlist.split()]
565 | self.merge_type = merge_type
566 |
567 | self.layers = nn.ModuleList()
568 | if attn_type == 0: # the default attention
569 | for i in range(n_layer):
570 | # dropatt = dropatt * 2 if (i == 0 and rnnenc) else dropatt
571 | use_penalty = i in self.attn_pen_layers
572 | self.layers.append(
573 | RelPartialLearnableDecoderLayer(
574 | n_head, d_model, d_head, d_inner, dropout,
575 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
576 | dropatt=dropatt, pre_lnorm=pre_lnorm, penalty=use_penalty)
577 | )
578 | elif attn_type == 1: # learnable embeddings
579 | for i in range(n_layer):
580 | self.layers.append(
581 | RelLearnableDecoderLayer(
582 | n_head, d_model, d_head, d_inner, dropout,
583 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
584 | dropatt=dropatt, pre_lnorm=pre_lnorm)
585 | )
586 | elif attn_type in [2, 3]: # absolute embeddings
587 | for i in range(n_layer):
588 | use_penalty = i in self.attn_pen_layers
589 | # dropatt = dropatt * 0 if (i == 0 and rnnenc) else dropatt
590 | self.layers.append(
591 | DecoderLayer(
592 | n_head, d_model, d_head, d_inner, dropout,
593 | dropatt=dropatt, pre_lnorm=pre_lnorm, penalty=use_penalty)
594 | )
595 |
596 | self.sample_softmax = sample_softmax
597 | # use sampled softmax
598 | if sample_softmax > 0:
599 | self.out_layer = nn.Linear(d_model, n_token)
600 | if tie_weight:
601 | self.out_layer.weight = self.word_emb.weight
602 | self.tie_weight = tie_weight
603 | self.sampler = LogUniformSampler(n_token, sample_softmax)
604 |
605 | # use adaptive softmax (including standard softmax)
606 | else:
607 | self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
608 | cutoffs, div_val=div_val)
609 |
610 | if tie_weight:
611 | for i in range(len(self.crit.out_layers)):
612 | self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
613 |
614 | if tie_projs:
615 | for i, tie_proj in enumerate(tie_projs):
616 | if tie_proj and div_val == 1 and d_model != d_embed:
617 | self.crit.out_projs[i] = self.word_emb.emb_projs[0]
618 | elif tie_proj and div_val != 1:
619 | self.crit.out_projs[i] = self.word_emb.emb_projs[i]
620 |
621 | self.rnnenc = rnnenc
622 | self.rnndim = rnndim
623 | self.same_length = same_length
624 | self.clamp_len = clamp_len
625 |
626 | self._create_params()
627 |
628 | def backward_compatible(self):
629 | self.sample_softmax = -1
630 |
631 | def _create_params(self):
632 | if self.attn_type == 0: # default attention
633 | self.pos_emb = PositionalEmbedding(self.d_model)
634 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
635 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
636 | elif self.attn_type == 1: # learnable
637 | self.r_emb = nn.Parameter(torch.Tensor(
638 | self.n_layer, self.max_klen, self.n_head, self.d_head))
639 | self.r_w_bias = nn.Parameter(torch.Tensor(
640 | self.n_layer, self.n_head, self.d_head))
641 | self.r_bias = nn.Parameter(torch.Tensor(
642 | self.n_layer, self.max_klen, self.n_head))
643 | elif self.attn_type == 2: # absolute standard
644 | self.pos_emb = PositionalEmbedding(self.d_model)
645 | elif self.attn_type == 3: # absolute deeper SA
646 | self.r_emb = nn.Parameter(torch.Tensor(
647 | self.n_layer, self.max_klen, self.n_head, self.d_head))
648 | # for RNN projection
649 |         if self.rnnenc and self.rnndim != self.d_model and hasattr(self, 'rnnproj'):  # rnnproj only exists for 'gating'/'project'
650 | self.rnnproj.bias.data.zero_()
651 | self.rnnproj.weight.data.uniform_(-0.1, 0.1)
652 |
653 | def reset_length(self, tgt_len, ext_len, mem_len):
654 | self.tgt_len = tgt_len
655 | self.mem_len = mem_len
656 | self.ext_len = ext_len
657 |
658 | def init_mems(self):
659 | if self.mem_len > 0:
660 | mems = []
661 | param = next(self.parameters())
662 | for i in range(self.n_layer+1):
663 | empty = torch.empty(0, dtype=param.dtype, device=param.device)
664 | mems.append(empty)
665 |
666 | return mems
667 | else:
668 | return None
669 |
670 | def _update_mems(self, hids, mems, qlen, mlen):
671 | # does not deal with None
672 | if mems is None: return None
673 |
674 | # mems is not None
675 | assert len(hids) == len(mems), 'len(hids) != len(mems)'
676 |
677 | # There are `mlen + qlen` steps that can be cached into mems
678 | # For the next step, the last `ext_len` of the `qlen` tokens
679 | # will be used as the extended context. Hence, we only cache
680 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
681 | # to `mlen + qlen - self.ext_len`.
682 | with torch.no_grad():
683 | new_mems = []
684 |             end_idx = mlen + max(0, qlen - self.ext_len)
685 | beg_idx = max(0, end_idx - self.mem_len)
686 | for i in range(len(hids)):
687 |
688 | cat = torch.cat([mems[i], hids[i]], dim=0)
689 | new_mems.append(cat[beg_idx:end_idx].detach())
690 |
691 | return new_mems
692 |
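    |     # One zero-initialised (h, c) pair per entry in rnnlayer_list; the caller threads the
    |     # returned list through forward() to carry the LSTM state across segments.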
693 | def init_hidden(self, bsz):
694 | hidden_list = []
695 | for i in self.rnnlayer_list:
696 | weight = next(self.parameters())
697 | hidden_list.append((weight.new_zeros(1, bsz, getattr(self, 'rnndim', self.d_model)),
698 | weight.new_zeros(1, bsz, getattr(self, 'rnndim', self.d_model))))
699 | return hidden_list
700 |
701 | def init_hidden_singlelayer(self, bsz):
702 | weight = next(self.parameters())
703 | hidden = (weight.new_zeros(1, bsz, self.rnndim), weight.new_zeros(1, bsz, self.rnndim))
704 | return hidden
705 |
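    |     # Runs the LSTM attached to transformer layer `i` over the current hidden states.
    |     # In batch mode the carried (h, c) state is advanced by the context preceding the
    |     # target tokens (or by the whole non-future segment when ext_len == 0); in stepwise
    |     # mode the tokens are fed one at a time and the per-step states are stacked along
    |     # the first dimension. Any future segment is re-encoded from a freshly initialised
    |     # hidden state, so it never updates the carried state.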
706 | def forward_rnn(self, i, core_out, rnn_hidden, stepwise=False, future_seqlen=0, tgtlen=0):
707 | # gs534 - rnn in the middle of a transformer layer
708 | index = self.rnnlayer_list.index(i)
709 | if not stepwise:
710 | if self.ext_len > 0:
711 | rnn_out_ext = None
712 | if tgtlen != core_out.size(0) - future_seqlen:
713 | rnn_out_ext, rnn_hidden[index] = self.rnn_list[index](
714 | core_out[:core_out.size(0)-future_seqlen-tgtlen], rnn_hidden[index])
715 | rnn_out, _ = self.rnn_list[index](
716 | core_out[core_out.size(0)-future_seqlen-tgtlen:core_out.size(0)-future_seqlen], rnn_hidden[index])
717 | if rnn_out_ext is not None:
718 | rnn_out = torch.cat([rnn_out_ext, rnn_out], dim=0)
719 | else:
720 | rnn_out, rnn_hidden[index] = self.rnn_list[index](
721 | core_out[:core_out.size(0)-future_seqlen], rnn_hidden[index])
722 | if future_seqlen > 1:
723 | new_hiddens = self.init_hidden_singlelayer(core_out.size(1))
724 | rnnout_future, _ = self.rnn_list[index](core_out[-future_seqlen+1:], new_hiddens)
725 | rnn_out = torch.cat([rnn_out, core_out[-future_seqlen:-future_seqlen+1], rnnout_future], dim=0)
726 | else:
727 | step_hidden = rnn_hidden[index]
728 | rnn_hidden_step = []
729 | core_out_list = []
730 | for k in range(core_out.size(0) - future_seqlen):
731 | rnn_output, step_hidden = self.rnn_list[index](core_out[k:k+1], step_hidden)
732 | core_out_list.append(rnn_output)
733 | rnn_hidden_step.append(step_hidden)
734 | rnn_hidden_step = list(zip(*rnn_hidden_step))
735 | rnn_hidden[index] = (torch.cat(rnn_hidden_step[0], dim=0), torch.cat(rnn_hidden_step[1], dim=0))
736 | rnn_out = torch.cat(core_out_list, dim=0)
737 |             if future_seqlen > 1:  # as in the non-stepwise branch above; the slice below degenerates when future_seqlen == 1
738 | new_hiddens = self.init_hidden_singlelayer(core_out.size(1))
739 | rnnout_future, _ = self.rnn_list[index](core_out[-future_seqlen+1:], new_hiddens)
740 | # rnnout_future = core_out[-future_seqlen:]
741 | rnn_out = torch.cat([rnn_out, core_out[-future_seqlen:-future_seqlen+1], rnnout_future], dim=0)
742 | assert (self.rnndim == self.d_model or self.merge_type == 'project')
743 | return rnn_out, rnn_hidden
744 |
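    |     # Builds the two (qlen, klen, 1) masks used when a future segment of length flen is
    |     # appended to the input (1 = blocked): in the first, current-chunk queries keep the
    |     # usual causal pattern and cannot see the future segment; in the second they may also
    |     # attend to the future segment beyond its first position. Future-segment queries only
    |     # attend within the future segment.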
745 | def get_future_mask(self, qlen, flen, word_emb, mlen=0):
746 | klen = qlen + mlen
747 | dec_attn_mask = torch.triu(
748 | word_emb.new_ones(qlen-flen, klen-flen), diagonal=1+mlen)
749 | # various paddings
750 | ones_bottom_left = word_emb.new_ones(flen, klen-flen)
751 | ones_single_column = word_emb.new_ones(qlen, 1 if flen > 0 else 0)
752 | ones_top_right = word_emb.new_ones(qlen-flen, flen-1 if flen > 0 else 0)
753 | zeros_bottom_right = word_emb.new_zeros(flen, flen-1 if flen > 0 else 0)
754 | zeros_top_right = word_emb.new_zeros(qlen-flen, flen-1 if flen > 0 else 0)
755 | # construct half masks
756 | attn_mask_right = torch.cat([ones_top_right, zeros_bottom_right], dim=0)
757 | attn_mask_right_future = torch.cat([zeros_top_right, zeros_bottom_right], dim=0)
758 | attn_mask_left = torch.cat([dec_attn_mask, ones_bottom_left], dim=0)
759 | # construct attn masks
760 | dec_attn_mask = torch.cat([attn_mask_left, ones_single_column, attn_mask_right], dim=-1)
761 | dec_attn_mask_future = torch.cat(
762 | [attn_mask_left, ones_single_column, attn_mask_right_future], dim=-1)
763 | return dec_attn_mask[:,:,None], dec_attn_mask_future[:,:,None]
764 |
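    |     # Full segment-level pass: embed the inputs, pick the attention mask (same_length,
    |     # plain causal, or the future-aware pair from get_future_mask), run the stacked
    |     # layers while interleaving the LSTMs according to merge_type, then update the
    |     # cached memory and sum the attention penalties of the penalised layers.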
765 | def _forward(self, dec_inp, mems=None, rnn_hidden=None, stepwise=False, future_seqlen=0, tgt_len=0):
766 | qlen, bsz = dec_inp.size()
767 | attn_pen_list = []
768 |
769 | word_emb = self.word_emb(dec_inp)
770 |
771 | mlen = mems[0].size(0) if mems is not None else 0
772 | klen = mlen + qlen
773 | if self.same_length:
774 | all_ones = word_emb.new_ones(qlen, klen)
775 | mask_len = klen - self.mem_len
776 | if mask_len > 0:
777 | mask_shift_len = qlen - mask_len
778 | else:
779 | mask_shift_len = qlen
780 | dec_attn_mask = (torch.triu(all_ones, 1+mlen)
781 | + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
782 | elif self.future_len == 0:
783 | dec_attn_mask = torch.triu(
784 | word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
785 | dec_attn_mask = dec_attn_mask.byte()
786 | else:
787 | dec_attn_mask_normal, dec_attn_mask_future = self.get_future_mask(
788 | qlen, future_seqlen, word_emb, mlen)
789 | dec_attn_mask_normal = dec_attn_mask_normal.byte()
790 | dec_attn_mask_future = dec_attn_mask_future.byte()
791 |
792 | hids = []
793 | if self.attn_type == 0: # default
794 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
795 | dtype=word_emb.dtype)
796 | if self.clamp_len > 0:
797 | pos_seq.clamp_(max=self.clamp_len)
798 | pos_emb = self.pos_emb(pos_seq)
799 |
800 | core_out = self.drop(word_emb)
801 | pos_emb = self.drop(pos_emb)
802 | if self.rnnenc and 0 in self.rnnlayer_list:
803 | rnn_out, rnn_hidden = self.forward_rnn(0, core_out, rnn_hidden, stepwise, future_seqlen)
804 | if self.merge_type == 'project':
805 | core_out = torch.relu(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))
806 | # core_out = (self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))
807 | elif self.merge_type == 'gating':
808 | core_gating = torch.sigmoid(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))
809 | core_out = rnn_out * core_gating + core_out * (1 - core_gating)
810 | else:
811 | core_out = rnn_out
812 | hids.append(core_out)
813 |
814 | for i, layer in enumerate(self.layers):
815 | if self.future_len != 0:
816 |                     dec_attn_mask = dec_attn_mask_future if i in self.layer_list else dec_attn_mask_normal
817 | mems_i = None if mems is None else mems[i]
818 | core_out, attn_pen = layer(core_out, pos_emb, self.r_w_bias,
819 | self.r_r_bias, dec_attn_mask=dec_attn_mask,
820 | mems=mems_i)
821 | # gs534 - rnn in the middle of a transformer layer
822 | if self.rnnenc and i+1 in self.rnnlayer_list:
823 | rnn_out, rnn_hidden = self.forward_rnn(i+1, core_out, rnn_hidden, stepwise, future_seqlen)
824 | if self.merge_type == 'project':
825 | core_out = torch.relu(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))
826 | # core_out = (self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))
827 | elif self.merge_type == 'gating':
828 | core_gating = torch.sigmoid(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))
829 | core_out = rnn_out * core_gating + core_out * (1 - core_gating)
830 | else:
831 | core_out = rnn_out
832 | attn_pen_list.append(attn_pen)
833 | hids.append(core_out)
834 | elif self.attn_type == 1: # learnable
835 | core_out = self.drop(word_emb)
836 | hids.append(core_out)
837 | for i, layer in enumerate(self.layers):
838 | if self.clamp_len > 0:
839 | r_emb = self.r_emb[i][-self.clamp_len :]
840 | r_bias = self.r_bias[i][-self.clamp_len :]
841 | else:
842 | r_emb, r_bias = self.r_emb[i], self.r_bias[i]
843 |
844 | mems_i = None if mems is None else mems[i]
845 | core_out = layer(core_out, r_emb, self.r_w_bias[i],
846 | r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
847 | hids.append(core_out)
848 | elif self.attn_type == 2: # absolute
849 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
850 | dtype=word_emb.dtype)
851 | if self.clamp_len > 0:
852 | pos_seq.clamp_(max=self.clamp_len)
853 | pos_emb = self.pos_emb(pos_seq)
854 |
855 | word_emb = self.drop(word_emb)
856 | # rnn_out, rnn_hidden = self.forward_rnn(word_emb, rnn_hidden, stepwise, future_seqlen, tgt_len)
857 | core_out = word_emb
858 |
859 | hids.append(core_out)
860 | for i, layer in enumerate(self.layers):
861 | if self.future_len != 0:
862 | dec_attn_mask = dec_attn_mask_future if i in self.layer_list else dec_attn_mask_normal
863 | mems_i = None if mems is None else mems[i]
864 | if mems_i is not None and mlen > 0:
865 | mems_i += pos_emb[:mlen]
866 | # gs534 - rnn in the middle of a transformer layer
867 |                 if self.rnnenc and i in self.rnnlayer_list:
868 |                     # run the interleaved LSTM and project the concatenated states back to d_model
    |                     rnn_out, rnn_hidden = self.forward_rnn(i, core_out, rnn_hidden, stepwise, future_seqlen)
    |                     core_out = self.drop(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))
869 | core_out, attn_pen = layer(core_out, dec_attn_mask=dec_attn_mask,
870 | mems=mems_i)
871 | attn_pen_list.append(attn_pen)
872 | hids.append(core_out)
873 | elif self.attn_type == 3:
874 | core_out = self.drop(word_emb)
875 |
876 | hids.append(core_out)
877 | for i, layer in enumerate(self.layers):
878 | mems_i = None if mems is None else mems[i]
879 | if mems_i is not None and mlen > 0:
880 | cur_emb = self.r_emb[i][:-qlen]
881 | cur_size = cur_emb.size(0)
882 | if cur_size < mlen:
883 | cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
884 | cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
885 | else:
886 | cur_emb = cur_emb[-mlen:]
887 | mems_i += cur_emb.view(mlen, 1, -1)
888 | core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
889 |
890 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
891 | mems=mems_i)
892 | hids.append(core_out)
893 |
894 | core_out = self.drop(core_out)
895 |         new_mems = self._update_mems(hids, mems, qlen, mlen)  # argument order follows the _update_mems signature
896 | attn_pen = sum(attn_pen_list)[0] if attn_pen_list != [] else core_out.new_zeros(1)
897 |
898 | return core_out, new_mems, rnn_hidden, attn_pen
899 |
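    |     # Wrapper around _forward. Returns a flat list,
    |     #   [loss] + new_mems + [attn_pen] + [rnn_hidden]
    |     # (new_mems is dropped when mem_len == 0), so the caller must slice off the last
    |     # two entries before reusing the returned memory.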
900 | def forward(self, data, target, *mems, rnn_hidden=None, stepwise=False, future_seqlen=0):
901 | # nn.DataParallel does not allow size(0) tensors to be broadcasted.
902 | # So, have to initialize size(0) mems inside the model forward.
903 | # Moreover, have to return new_mems to allow nn.DataParallel to piece
904 | # them together.
905 | if not mems: mems = self.init_mems()
906 |
907 | tgt_len = target.size(0)
908 | hidden, new_mems, rnn_hidden, attn_pen = self._forward(data,
909 | mems=mems,
910 | rnn_hidden=rnn_hidden,
911 | stepwise=stepwise,
912 | future_seqlen=future_seqlen,
913 | tgt_len=tgt_len)
914 |
915 | if future_seqlen > 0:
916 | pred_hid = hidden[-tgt_len-future_seqlen:-future_seqlen]
917 | else:
918 | pred_hid = hidden[-tgt_len:]
919 | if self.sample_softmax > 0 and self.training:
920 | assert self.tie_weight
921 | logit = sample_logits(self.word_emb,
922 | self.out_layer.bias, target, pred_hid, self.sampler)
923 | loss = -F.log_softmax(logit, -1)[:, :, 0]
924 | else:
925 | loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
926 | loss = loss.view(tgt_len, -1)
927 |
928 | if new_mems is None:
929 | return [loss] + [attn_pen] + [rnn_hidden]
930 | else:
931 | return [loss] + new_mems + [attn_pen] + [rnn_hidden]
932 |
933 | if __name__ == '__main__':
934 | import argparse
935 |
936 | parser = argparse.ArgumentParser(description='unit test')
937 |
938 | parser.add_argument('--n_layer', type=int, default=4, help='')
939 | parser.add_argument('--n_rel_layer', type=int, default=4, help='')
940 | parser.add_argument('--n_head', type=int, default=2, help='')
941 | parser.add_argument('--d_head', type=int, default=2, help='')
942 | parser.add_argument('--d_model', type=int, default=200, help='')
943 | parser.add_argument('--d_embed', type=int, default=200, help='')
944 | parser.add_argument('--d_inner', type=int, default=200, help='')
945 | parser.add_argument('--dropout', type=float, default=0.0, help='')
946 | parser.add_argument('--cuda', action='store_true', help='')
947 | parser.add_argument('--seed', type=int, default=1111, help='')
948 | parser.add_argument('--multi_gpu', action='store_true', help='')
949 |
950 | args = parser.parse_args()
951 |
952 | device = torch.device("cuda" if args.cuda else "cpu")
953 |
954 | B = 4
955 | tgt_len, mem_len, ext_len = 36, 36, 0
956 | data_len = tgt_len * 20
957 | args.n_token = 10000
958 |
959 | import data_utils
960 |
961 | data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
962 | diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)
963 |
964 | cutoffs = [args.n_token // 2]
965 | tie_projs = [False] + [True] * len(cutoffs)
966 |
967 | for div_val in [1, 2]:
968 | for d_embed in [200, 100]:
969 | model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
970 | args.d_model, args.d_head, args.d_inner, args.dropout,
971 | dropatt=args.dropout, tie_weight=True,
972 | d_embed=d_embed, div_val=div_val,
973 | tie_projs=tie_projs, pre_lnorm=True,
974 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
975 | cutoffs=cutoffs, attn_type=0).to(device)
976 |
977 | print(sum(p.numel() for p in model.parameters()))
978 |
979 | mems = tuple()
980 | for idx, (inp, tgt, seqlen) in enumerate(diter):
981 | print('batch {}'.format(idx))
982 | out = model(inp, tgt, *mems)
983 |                 mems = out[1:-2]  # drop the trailing attn_pen and rnn_hidden entries
984 |
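    |     # Hedged usage sketch (illustrative settings, not part of the original test): exercise
    |     # the RNN-interleaved path by attaching an LSTM in front of layer 0 ('direct' merge,
    |     # rnndim == d_model) and carrying both mems and rnn_hidden across segments.
    |     model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
    |                              args.d_model, args.d_head, args.d_inner, args.dropout,
    |                              dropatt=args.dropout, tie_weight=True,
    |                              d_embed=args.d_embed, div_val=1,
    |                              tie_projs=tie_projs, pre_lnorm=True,
    |                              tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
    |                              cutoffs=cutoffs, attn_type=0,
    |                              rnnenc=True, rnndim=args.d_model,
    |                              layer_list='0', merge_type='direct').to(device)
    |     mems = tuple()
    |     rnn_hidden = model.init_hidden(B)
    |     for idx, (inp, tgt, seqlen) in enumerate(diter):
    |         print('rnn batch {}'.format(idx))
    |         out = model(inp, tgt, *mems, rnn_hidden=rnn_hidden)
    |         # forward returns [loss] + new_mems + [attn_pen] + [rnn_hidden]
    |         mems, rnn_hidden = out[1:-2], out[-1]
    |         # detach the carried LSTM state so gradients do not flow across segments
    |         rnn_hidden = [(h.detach(), c.detach()) for h, c in rnn_hidden]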
--------------------------------------------------------------------------------