├── .gitignore ├── LICENSE ├── README.md ├── dataset │   ├── common.py │   ├── lm.py │   ├── problem.py │   └── translation.py ├── decoder.py ├── model │   ├── fast_transformer.py │   └── transformer.py ├── requirements.txt ├── train.py └── utils     ├── optimizer.py     └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | __pycache__ 4 | .eggs 5 | .mypy_cache 6 | .pytest_cache 7 | *.egg-info 8 | build 9 | dist 10 | *.so 11 | 12 | .data 13 | data 14 | output 15 | *_data* 16 | *output* 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Choongwoo Han 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Transformer 3 | 4 | This is a PyTorch implementation of the 5 | [Transformer](https://arxiv.org/abs/1706.03762) model, in the style of 6 | [tensorflow/tensor2tensor](https://github.com/tensorflow/tensor2tensor). 7 | 8 | ## Prerequisites 9 | 10 | I tested it with PyTorch 1.0.0 and Python 3.6.8. 11 | 12 | The `wmt32k` problem, a de/en translation dataset, is tokenized with 13 | [spaCy](https://spacy.io/usage/), so if you want to run it, download the 14 | spaCy language models first with the following commands. 15 | 16 | ``` 17 | $ pip install spacy 18 | $ python -m spacy download en 19 | $ python -m spacy download de 20 | ``` 21 | 22 | ## Usage 23 | 24 | 1. Train a model. 25 | ``` 26 | $ python train.py --problem wmt32k --output_dir ./output --data_dir ./wmt32k_data 27 | or 28 | $ python train.py --problem lm1b --output_dir ./output --data_dir ./lm1b_data 29 | ``` 30 | 31 | If you want to try `fast_transformer`, install 32 | [tcop-pytorch](https://github.com/tunz/tcop-pytorch) first and pass the `--model` argument. 33 | ``` 34 | $ python train.py --problem lm1b --output_dir ./output --data_dir ./lm1b_data --model fast_transformer 35 | ``` 36 | 37 | 38 | 2. You can translate a single sentence with the trained model.
39 | ``` 40 | $ python decoder.py --translate --data_dir ./wmt32k_data --model_dir ./output/last/models 41 | ``` 42 | -------------------------------------------------------------------------------- /dataset/common.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | from torchtext.data import Iterator 5 | from tqdm import tqdm 6 | 7 | 8 | class BucketByLengthIterator(Iterator): 9 | def __init__(self, *args, max_length=None, example_length_fn=None, 10 | data_paths=None, **kwargs): 11 | batch_size = kwargs['batch_size'] 12 | 13 | self.boundaries = self._bucket_boundaries(max_length) 14 | self.batch_sizes = self._batch_sizes(batch_size) 15 | self.example_length_fn = example_length_fn 16 | self.data_paths = data_paths 17 | self.data_path_idx = 0 18 | self.buckets = [[] for _ in range(len(self.boundaries)+1)] 19 | 20 | super(BucketByLengthIterator, self).__init__(*args, **kwargs) 21 | 22 | def create_batches(self): 23 | self.batches = self._bucket_by_seq_length(self.data()) 24 | 25 | def reload_examples(self): 26 | self.data_path_idx = (self.data_path_idx + 1) % len(self.data_paths) 27 | data_path = self.data_paths[self.data_path_idx] 28 | 29 | examples = torch.load(data_path) 30 | self.dataset.examples = examples 31 | 32 | def _bucket_by_seq_length(self, data): 33 | for ex in data: 34 | length = self.example_length_fn(ex) 35 | 36 | idx = None 37 | for i, boundary in enumerate(self.boundaries): 38 | if length <= boundary: 39 | idx = i 40 | break 41 | assert idx is not None 42 | 43 | self.buckets[idx].append(ex) 44 | if len(self.buckets[idx]) >= self.batch_sizes[idx]: 45 | yield self.buckets[idx] 46 | self.buckets[idx] = [] 47 | 48 | def _bucket_boundaries(self, max_length, min_length=8, 49 | length_bucket_step=1.1): 50 | x = min_length 51 | boundaries = [] 52 | while x < max_length: 53 | boundaries.append(x) 54 | x = max(x + 1, int(x * length_bucket_step)) 55 | return boundaries + [max_length] 56 | 57 | def _batch_sizes(self, batch_size): 58 | batch_sizes = [ 59 | max(1, batch_size // length) for length in self.boundaries 60 | ] 61 | max_batch_size = max(batch_sizes) 62 | highly_composite_numbers = [ 63 | 1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 64 | 1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 65 | 50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 66 | 498960, 554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 67 | 3603600, 4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 68 | 17297280, 21621600, 32432400, 36756720, 43243200, 61261200, 69 | 73513440, 110270160 70 | ] 71 | window_size = max( 72 | [i for i in highly_composite_numbers if i <= 3 * max_batch_size]) 73 | divisors = [i for i in range(1, window_size + 1) 74 | if window_size % i == 0] 75 | return [max([d for d in divisors if d <= bs]) for bs in batch_sizes] 76 | 77 | 78 | def pickles_to_torch(data_paths): 79 | print("Refining pickle data...") 80 | for data_path in tqdm(data_paths, ascii=True): 81 | examples = [] 82 | with open(data_path, 'rb') as f: 83 | while True: 84 | try: 85 | example = pickle.load(f) 86 | except EOFError: 87 | break 88 | examples.append(example) 89 | 90 | with open(data_path, 'wb') as f: 91 | torch.save(examples, f) 92 | -------------------------------------------------------------------------------- /dataset/lm.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, OrderedDict 2 | import glob 3 | 
import io 4 | import os 5 | import pickle 6 | 7 | import torch 8 | from torchtext import data 9 | from tqdm import tqdm 10 | 11 | from dataset import common 12 | 13 | # pylint: disable=arguments-differ 14 | 15 | 16 | def split_tokenizer(x): 17 | return x.split() 18 | 19 | 20 | def read_examples(paths, fields, data_dir, mode, filter_pred, num_shard): 21 | data_path_fmt = data_dir + '/examples-' + mode + '-{}.pt' 22 | data_paths = [data_path_fmt.format(i) for i in range(num_shard)] 23 | writers = [open(data_path, 'wb') for data_path in data_paths] 24 | shard = 0 25 | 26 | for path in paths: 27 | print("Preprocessing {}".format(path)) 28 | 29 | with io.open(path, mode='r', encoding='utf-8') as trg_file: 30 | for trg_line in tqdm(trg_file, ascii=True): 31 | trg_line = trg_line.strip() 32 | if trg_line == '': 33 | continue 34 | 35 | example = data.Example.fromlist([trg_line], fields) 36 | if not filter_pred(example): 37 | continue 38 | 39 | pickle.dump(example, writers[shard]) 40 | shard = (shard + 1) % num_shard 41 | 42 | for writer in writers: 43 | writer.close() 44 | 45 | # Reload pickled objects, and save them again as a list. 46 | common.pickles_to_torch(data_paths) 47 | 48 | examples = torch.load(data_paths[0]) 49 | return examples, data_paths 50 | 51 | 52 | class LM1b(data.Dataset): 53 | urls = ["http://www.statmt.org/lm-benchmark/" 54 | "1-billion-word-language-modeling-benchmark-r13output.tar.gz"] 55 | name = 'lm1b' 56 | dirname = '' 57 | 58 | @staticmethod 59 | def sort_key(ex): 60 | return len(ex.trg) 61 | 62 | @classmethod 63 | def splits(cls, fields, data_dir, root='.data', **kwargs): 64 | if not isinstance(fields[0], (tuple, list)): 65 | fields = [('trg', fields[0])] 66 | 67 | filter_pred = kwargs['filter_pred'] 68 | 69 | expected_dir = os.path.join(root, cls.name) 70 | path = (expected_dir if os.path.exists(expected_dir) 71 | else cls.download(root)) 72 | 73 | lm_data_dir = "1-billion-word-language-modeling-benchmark-r13output" 74 | 75 | train_files = [ 76 | os.path.join(path, 77 | lm_data_dir, 78 | "training-monolingual.tokenized.shuffled", 79 | "news.en-%05d-of-00100" % i) for i in range(1, 100) 80 | ] 81 | train_examples, data_paths = \ 82 | read_examples(train_files, fields, data_dir, 'train', 83 | filter_pred, 100) 84 | 85 | val_files = [ 86 | os.path.join(path, 87 | lm_data_dir, 88 | "heldout-monolingual.tokenized.shuffled", 89 | "news.en.heldout-00000-of-00050") 90 | ] 91 | val_examples, _ = read_examples(val_files, fields, data_dir, 92 | 'val', filter_pred, 1) 93 | 94 | train_data = cls(train_examples, fields, **kwargs) 95 | val_data = cls(val_examples, fields, **kwargs) 96 | return (train_data, val_data, data_paths) 97 | 98 | 99 | def len_of_example(example): 100 | return len(example.trg) + 1 101 | 102 | 103 | def build_vocabs(trg_field, data_paths): 104 | trg_counter = Counter() 105 | for data_path in tqdm(data_paths, ascii=True): 106 | examples = torch.load(data_path) 107 | for x in examples: 108 | trg_counter.update(x.trg) 109 | 110 | specials = list(OrderedDict.fromkeys( 111 | tok for tok in [trg_field.unk_token, 112 | trg_field.pad_token, 113 | trg_field.init_token, 114 | trg_field.eos_token] 115 | if tok is not None)) 116 | trg_field.vocab = trg_field.vocab_cls(trg_counter, specials=specials, 117 | min_freq=300) 118 | 119 | 120 | def prepare(max_length, batch_size, device, opt, data_dir): 121 | pad = '' 122 | load_preprocessed = os.path.exists(data_dir + '/target.pt') 123 | 124 | def filter_pred(x): 125 | return len(x.trg) < max_length 126 | 127 | if 
load_preprocessed: 128 | print("Loading preprocessed data...") 129 | trg_field = torch.load(data_dir + '/target.pt')['field'] 130 | 131 | data_paths = glob.glob(data_dir + '/examples-train-*.pt') 132 | examples_train = torch.load(data_paths[0]) 133 | examples_val = torch.load(data_dir + '/examples-val-0.pt') 134 | 135 | fields = [('trg', trg_field)] 136 | train = LM1b(examples_train, fields, filter_pred=filter_pred) 137 | val = LM1b(examples_val, fields, filter_pred=filter_pred) 138 | else: 139 | trg_field = data.Field(tokenize=split_tokenizer, batch_first=True, 140 | pad_token=pad, lower=True, eos_token='') 141 | 142 | print("Loading data... (this may take a while)") 143 | train, val, data_paths = \ 144 | LM1b.splits(fields=(trg_field,), 145 | data_dir=data_dir, 146 | filter_pred=filter_pred) 147 | # fields = [('trg', trg_field)] 148 | # data_paths = glob.glob(data_dir + '/examples-train-*.pt') 149 | # examples_train = torch.load(data_paths[0]) 150 | # examples_val = torch.load(data_dir + '/examples-val-0.pt') 151 | # train = LM1b(examples_train, fields, filter_pred=filter_pred) 152 | # val = LM1b(examples_val, fields, filter_pred=filter_pred) 153 | 154 | print("Building vocabs... (this may take a while)") 155 | build_vocabs(trg_field, data_paths) 156 | 157 | print("Creating iterators...") 158 | train_iter, val_iter = common.BucketByLengthIterator.splits( 159 | (train, val), 160 | data_paths=data_paths, 161 | batch_size=batch_size, 162 | device=device, 163 | max_length=max_length, 164 | example_length_fn=len_of_example) 165 | 166 | opt.src_vocab_size = None 167 | opt.trg_vocab_size = len(trg_field.vocab) 168 | opt.src_pad_idx = None 169 | opt.trg_pad_idx = trg_field.vocab.stoi[pad] 170 | opt.has_inputs = False 171 | 172 | if not load_preprocessed: 173 | torch.save({'pad_idx': opt.trg_pad_idx, 'field': trg_field}, 174 | data_dir + '/target.pt') 175 | 176 | return train_iter, val_iter, opt 177 | -------------------------------------------------------------------------------- /dataset/problem.py: -------------------------------------------------------------------------------- 1 | 2 | def prepare(problem_set, data_dir, max_length, batch_size, device, opt): 3 | if problem_set not in ['wmt32k', 'lm1b']: 4 | raise Exception("only ['wmt32k', 'lm1b'] problem set supported.") 5 | 6 | setattr(opt, 'share_target_embedding', False) 7 | setattr(opt, 'has_inputs', True) 8 | 9 | if problem_set == 'wmt32k': 10 | from dataset import translation 11 | train_iter, val_iter, opt = \ 12 | translation.prepare(max_length, batch_size, device, opt, data_dir) 13 | elif problem_set == 'lm1b': 14 | from dataset import lm 15 | train_iter, val_iter, opt = \ 16 | lm.prepare(max_length, batch_size, device, opt, data_dir) 17 | 18 | return train_iter, val_iter, opt.src_vocab_size, opt.trg_vocab_size, opt 19 | -------------------------------------------------------------------------------- /dataset/translation.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, OrderedDict 2 | import glob 3 | import io 4 | import os 5 | import pickle 6 | import re 7 | 8 | import torch 9 | from torchtext import data 10 | import spacy 11 | from tqdm import tqdm 12 | 13 | from dataset import common 14 | 15 | # pylint: disable=arguments-differ 16 | 17 | spacy_de = spacy.load('de') 18 | spacy_en = spacy.load('en') 19 | 20 | url = re.compile('(.*)') 21 | 22 | 23 | def tokenize_de(text): 24 | return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))] 25 | 26 | 27 | 
def tokenize_en(text): 28 | return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))] 29 | 30 | 31 | def read_examples(paths, exts, fields, data_dir, mode, filter_pred, num_shard): 32 | data_path_fmt = data_dir + '/examples-' + mode + '-{}.pt' 33 | data_paths = [data_path_fmt.format(i) for i in range(num_shard)] 34 | writers = [open(data_path, 'wb') for data_path in data_paths] 35 | shard = 0 36 | 37 | for path in paths: 38 | print("Preprocessing {}".format(path)) 39 | src_path, trg_path = tuple(path + x for x in exts) 40 | 41 | with io.open(src_path, mode='r', encoding='utf-8') as src_file, \ 42 | io.open(trg_path, mode='r', encoding='utf-8') as trg_file: 43 | for src_line, trg_line in tqdm(zip(src_file, trg_file), 44 | ascii=True): 45 | src_line, trg_line = src_line.strip(), trg_line.strip() 46 | if src_line == '' or trg_line == '': 47 | continue 48 | 49 | example = data.Example.fromlist( 50 | [src_line, trg_line], fields) 51 | if not filter_pred(example): 52 | continue 53 | 54 | pickle.dump(example, writers[shard]) 55 | shard = (shard + 1) % num_shard 56 | 57 | for writer in writers: 58 | writer.close() 59 | 60 | # Reload pickled objects, and save them again as a list. 61 | common.pickles_to_torch(data_paths) 62 | 63 | examples = torch.load(data_paths[0]) 64 | return examples, data_paths 65 | 66 | 67 | class WMT32k(data.Dataset): 68 | urls = ['http://data.statmt.org/wmt18/translation-task/' 69 | 'training-parallel-nc-v13.tgz', 70 | 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', 71 | 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz', 72 | 'http://data.statmt.org/wmt17/translation-task/dev.tgz'] 73 | name = 'wmt32k' 74 | dirname = '' 75 | 76 | @staticmethod 77 | def sort_key(ex): 78 | return data.interleave_keys(len(ex.src), len(ex.trg)) 79 | 80 | @classmethod 81 | def splits(cls, exts, fields, data_dir, root='.data', **kwargs): 82 | if not isinstance(fields[0], (tuple, list)): 83 | fields = [('src', fields[0]), ('trg', fields[1])] 84 | 85 | filter_pred = kwargs['filter_pred'] 86 | 87 | expected_dir = os.path.join(root, cls.name) 88 | path = (expected_dir if os.path.exists(expected_dir) 89 | else cls.download(root)) 90 | 91 | train_files = ['training-parallel-nc-v13/news-commentary-v13.de-en', 92 | 'commoncrawl.de-en', 93 | 'training/europarl-v7.de-en'] 94 | train_files = map(lambda x: os.path.join(path, x), train_files) 95 | train_examples, data_paths = \ 96 | read_examples(train_files, exts, fields, data_dir, 'train', 97 | filter_pred, 100) 98 | 99 | val_files = [os.path.join(path, 'dev/newstest2013')] 100 | val_examples, _ = read_examples(val_files, exts, fields, data_dir, 101 | 'val', filter_pred, 1) 102 | 103 | train_data = cls(train_examples, fields, **kwargs) 104 | val_data = cls(val_examples, fields, **kwargs) 105 | return (train_data, val_data, data_paths) 106 | 107 | 108 | def len_of_example(example): 109 | return max(len(example.src) + 1, len(example.trg) + 1) 110 | 111 | 112 | def build_vocabs(src_field, trg_field, data_paths): 113 | src_counter = Counter() 114 | trg_counter = Counter() 115 | for data_path in tqdm(data_paths, ascii=True): 116 | examples = torch.load(data_path) 117 | for x in examples: 118 | src_counter.update(x.src) 119 | trg_counter.update(x.trg) 120 | 121 | specials = list(OrderedDict.fromkeys( 122 | tok for tok in [src_field.unk_token, 123 | src_field.pad_token, 124 | src_field.init_token, 125 | src_field.eos_token] 126 | if tok is not None)) 127 | src_field.vocab = src_field.vocab_cls(src_counter, 
specials=specials, 128 | min_freq=50) 129 | trg_field.vocab = trg_field.vocab_cls(trg_counter, specials=specials, 130 | min_freq=50) 131 | 132 | 133 | def prepare(max_length, batch_size, device, opt, data_dir): 134 | pad = '' 135 | load_preprocessed = os.path.exists(data_dir + '/source.pt') 136 | 137 | def filter_pred(x): 138 | return len(x.src) < max_length and len(x.trg) < max_length 139 | 140 | if load_preprocessed: 141 | print("Loading preprocessed data...") 142 | src_field = torch.load(data_dir + '/source.pt')['field'] 143 | trg_field = torch.load(data_dir + '/target.pt')['field'] 144 | 145 | data_paths = glob.glob(data_dir + '/examples-train-*.pt') 146 | examples_train = torch.load(data_paths[0]) 147 | examples_val = torch.load(data_dir + '/examples-val-0.pt') 148 | 149 | fields = [('src', src_field), ('trg', trg_field)] 150 | train = WMT32k(examples_train, fields, filter_pred=filter_pred) 151 | val = WMT32k(examples_val, fields, filter_pred=filter_pred) 152 | else: 153 | src_field = data.Field(tokenize=tokenize_de, batch_first=True, 154 | pad_token=pad, lower=True, eos_token='') 155 | trg_field = data.Field(tokenize=tokenize_en, batch_first=True, 156 | pad_token=pad, lower=True, eos_token='') 157 | 158 | print("Loading data... (this may take a while)") 159 | train, val, data_paths = \ 160 | WMT32k.splits(exts=('.de', '.en'), 161 | fields=(src_field, trg_field), 162 | data_dir=data_dir, 163 | filter_pred=filter_pred) 164 | 165 | print("Building vocabs... (this may take a while)") 166 | build_vocabs(src_field, trg_field, data_paths) 167 | 168 | print("Creating iterators...") 169 | train_iter, val_iter = common.BucketByLengthIterator.splits( 170 | (train, val), 171 | data_paths=data_paths, 172 | batch_size=batch_size, 173 | device=device, 174 | max_length=max_length, 175 | example_length_fn=len_of_example) 176 | 177 | opt.src_vocab_size = len(src_field.vocab) 178 | opt.trg_vocab_size = len(trg_field.vocab) 179 | opt.src_pad_idx = src_field.vocab.stoi[pad] 180 | opt.trg_pad_idx = trg_field.vocab.stoi[pad] 181 | 182 | if not load_preprocessed: 183 | torch.save({'pad_idx': opt.src_pad_idx, 'field': src_field}, 184 | data_dir + '/source.pt') 185 | torch.save({'pad_idx': opt.trg_pad_idx, 'field': trg_field}, 186 | data_dir + '/target.pt') 187 | 188 | return train_iter, val_iter, opt 189 | -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from utils import utils 8 | 9 | # pylint: disable=not-callable 10 | 11 | 12 | def encode_inputs(sentence, model, src_data, beam_size, device): 13 | inputs = src_data['field'].preprocess(sentence) 14 | inputs.append(src_data['field'].eos_token) 15 | inputs = [inputs] 16 | inputs = src_data['field'].process(inputs, device=device) 17 | with torch.no_grad(): 18 | src_mask = utils.create_pad_mask(inputs, src_data['pad_idx']) 19 | enc_output = model.encode(inputs, src_mask) 20 | enc_output = enc_output.repeat(beam_size, 1, 1) 21 | return enc_output, src_mask 22 | 23 | 24 | def update_targets(targets, best_indices, idx, vocab_size): 25 | best_tensor_indices = torch.div(best_indices, vocab_size) 26 | best_token_indices = torch.fmod(best_indices, vocab_size) 27 | new_batch = torch.index_select(targets, 0, best_tensor_indices) 28 | new_batch[:, idx] = best_token_indices 29 | return new_batch 30 | 31 | 32 | def 
get_result_sentence(indices_history, trg_data, vocab_size): 33 | result = [] 34 | k = 0 35 | for best_indices in indices_history[::-1]: 36 | best_idx = best_indices[k] 37 | # TODO: get this vocab_size from target.pt? 38 | k = best_idx // vocab_size 39 | best_token_idx = best_idx % vocab_size 40 | best_token = trg_data['field'].vocab.itos[best_token_idx] 41 | result.append(best_token) 42 | return ' '.join(result[::-1]) 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--data_dir', type=str, required=True) 48 | parser.add_argument('--model_dir', type=str, required=True) 49 | parser.add_argument('--max_length', type=int, default=100) 50 | parser.add_argument('--beam_size', type=int, default=4) 51 | parser.add_argument('--alpha', type=float, default=0.6) 52 | parser.add_argument('--no_cuda', action='store_true') 53 | parser.add_argument('--translate', action='store_true') 54 | args = parser.parse_args() 55 | 56 | beam_size = args.beam_size 57 | 58 | # Load fields. 59 | if args.translate: 60 | src_data = torch.load(args.data_dir + '/source.pt') 61 | trg_data = torch.load(args.data_dir + '/target.pt') 62 | 63 | # Load a saved model. 64 | device = torch.device('cpu' if args.no_cuda else 'cuda') 65 | model = utils.load_checkpoint(args.model_dir, device) 66 | 67 | pads = torch.tensor([trg_data['pad_idx']] * beam_size, device=device) 68 | pads = pads.unsqueeze(-1) 69 | 70 | # We'll find a target sequence by beam search. 71 | scores_history = [torch.zeros((beam_size,), dtype=torch.float, 72 | device=device)] 73 | indices_history = [] 74 | cache = {} 75 | 76 | eos_idx = trg_data['field'].vocab.stoi[trg_data['field'].eos_token] 77 | 78 | if args.translate: 79 | sentence = input('Source? ') 80 | 81 | # Encoding inputs. 82 | if args.translate: 83 | start_time = time.time() 84 | enc_output, src_mask = encode_inputs(sentence, model, src_data, 85 | beam_size, device) 86 | targets = pads 87 | start_idx = 0 88 | else: 89 | enc_output, src_mask = None, None 90 | sentence = input('Target? ').split() 91 | for idx, _ in enumerate(sentence): 92 | sentence[idx] = trg_data['field'].vocab.stoi[sentence[idx]] 93 | sentence.append(trg_data['pad_idx']) 94 | targets = torch.tensor([sentence], device=device) 95 | start_idx = targets.size(1) - 1 96 | start_time = time.time() 97 | 98 | with torch.no_grad(): 99 | for idx in range(start_idx, args.max_length): 100 | if idx > start_idx: 101 | targets = torch.cat((targets, pads), dim=1) 102 | t_self_mask = utils.create_trg_self_mask(targets.size()[1], 103 | device=targets.device) 104 | 105 | t_mask = utils.create_pad_mask(targets, trg_data['pad_idx']) 106 | pred = model.decode(targets, enc_output, src_mask, 107 | t_self_mask, t_mask, cache) 108 | pred = pred[:, idx].squeeze(1) 109 | vocab_size = pred.size(1) 110 | 111 | pred = F.log_softmax(pred, dim=1) 112 | if idx == start_idx: 113 | scores = pred[0] 114 | else: 115 | scores = scores_history[-1].unsqueeze(1) + pred 116 | length_penalty = pow(((5. + idx + 1.) / 6.), args.alpha) 117 | scores = scores / length_penalty 118 | scores = scores.view(-1) 119 | 120 | best_scores, best_indices = scores.topk(beam_size, 0) 121 | scores_history.append(best_scores) 122 | indices_history.append(best_indices) 123 | 124 | # Stop searching when the best output of beam is EOS. 
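# best_indices indexes the flattened (beam_size * vocab_size) score vector,
# so `index % vocab_size` is the token id and `index // vocab_size` is the
# beam that produced it (cf. update_targets and get_result_sentence above).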
125 | if best_indices[0].item() % vocab_size == eos_idx: 126 | break 127 | 128 | targets = update_targets(targets, best_indices, idx, vocab_size) 129 | 130 | result = get_result_sentence(indices_history, trg_data, vocab_size) 131 | print("Result: {}".format(result)) 132 | 133 | print("Elapsed Time: {:.2f} sec".format(time.time() - start_time)) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /model/fast_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from tcop.masked_softmax import MaskedSoftmax 7 | 8 | from utils import utils 9 | from model.transformer import FeedForwardNetwork 10 | 11 | # pylint: disable=arguments-differ 12 | 13 | 14 | def initialize_weight(x): 15 | nn.init.xavier_uniform_(x.weight) 16 | if x.bias is not None: 17 | nn.init.constant_(x.bias, 0) 18 | 19 | 20 | class MultiHeadAttention(nn.Module): 21 | def __init__(self, hidden_size, dropout_rate, head_size=8): 22 | super(MultiHeadAttention, self).__init__() 23 | 24 | self.head_size = head_size 25 | 26 | self.att_size = att_size = hidden_size // head_size 27 | self.scale = att_size ** -0.5 28 | 29 | self.linear_q = nn.Linear(hidden_size, head_size * att_size, bias=False) 30 | self.linear_k = nn.Linear(hidden_size, head_size * att_size, bias=False) 31 | self.linear_v = nn.Linear(hidden_size, head_size * att_size, bias=False) 32 | initialize_weight(self.linear_q) 33 | initialize_weight(self.linear_k) 34 | initialize_weight(self.linear_v) 35 | 36 | self.att_dropout = nn.Dropout(dropout_rate) 37 | 38 | self.output_layer = nn.Linear(head_size * att_size, hidden_size, 39 | bias=False) 40 | initialize_weight(self.output_layer) 41 | 42 | def forward(self, q, k, v, mask, cache=None): 43 | orig_q_size = q.size() 44 | 45 | d_k = self.att_size 46 | d_v = self.att_size 47 | batch_size = q.size(0) 48 | 49 | # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i) 50 | q = self.linear_q(q).view(batch_size, -1, self.head_size, d_k) 51 | if cache is not None and 'encdec_k' in cache: 52 | k, v = cache['encdec_k'], cache['encdec_v'] 53 | else: 54 | k = self.linear_k(k).view(batch_size, -1, self.head_size, d_k) 55 | v = self.linear_v(v).view(batch_size, -1, self.head_size, d_v) 56 | 57 | if cache is not None: 58 | cache['encdec_k'], cache['encdec_v'] = k, v 59 | 60 | q = q.transpose(1, 2) # [b, h, q_len, d_k] 61 | v = v.transpose(1, 2) # [b, h, v_len, d_v] 62 | k = k.transpose(1, 2).transpose(2, 3) # [b, h, d_k, k_len] 63 | 64 | # Scaled Dot-Product Attention. 
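# MaskedSoftmax from tcop-pytorch applies the 1/sqrt(d_k) scaling and the
# length-encoded mask inside the softmax itself, replacing the explicit
# q.mul_(scale) / masked_fill_ / softmax sequence used in model/transformer.py.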
65 | # Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V 66 | x = torch.matmul(q, k) # [b, h, q_len, k_len] 67 | x = MaskedSoftmax.apply(x, mask, self.scale) 68 | 69 | x = self.att_dropout(x) 70 | x = x.matmul(v) # [b, h, q_len, attn] 71 | 72 | x = x.transpose(1, 2).contiguous() # [b, q_len, h, attn] 73 | x = x.view(batch_size, -1, self.head_size * d_v) 74 | 75 | x = self.output_layer(x) 76 | 77 | assert x.size() == orig_q_size 78 | return x 79 | 80 | 81 | class EncoderLayer(nn.Module): 82 | def __init__(self, hidden_size, filter_size, dropout_rate): 83 | super(EncoderLayer, self).__init__() 84 | 85 | self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) 86 | self.self_attention = MultiHeadAttention(hidden_size, dropout_rate) 87 | self.self_attention_dropout = nn.Dropout(dropout_rate) 88 | 89 | self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6) 90 | self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate) 91 | self.ffn_dropout = nn.Dropout(dropout_rate) 92 | 93 | def forward(self, x, mask): # pylint: disable=arguments-differ 94 | y = self.self_attention_norm(x) 95 | y = self.self_attention(y, y, y, mask) 96 | y = self.self_attention_dropout(y) 97 | x = x + y 98 | 99 | y = self.ffn_norm(x) 100 | y = self.ffn(y) 101 | y = self.ffn_dropout(y) 102 | x = x + y 103 | return x 104 | 105 | 106 | class DecoderLayer(nn.Module): 107 | def __init__(self, hidden_size, filter_size, dropout_rate): 108 | super(DecoderLayer, self).__init__() 109 | 110 | self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) 111 | self.self_attention = MultiHeadAttention(hidden_size, dropout_rate) 112 | self.self_attention_dropout = nn.Dropout(dropout_rate) 113 | 114 | self.enc_dec_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) 115 | self.enc_dec_attention = MultiHeadAttention(hidden_size, dropout_rate) 116 | self.enc_dec_attention_dropout = nn.Dropout(dropout_rate) 117 | 118 | self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6) 119 | self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate) 120 | self.ffn_dropout = nn.Dropout(dropout_rate) 121 | 122 | def forward(self, x, enc_output, self_mask, i_mask, cache): 123 | y = self.self_attention_norm(x) 124 | y = self.self_attention(y, y, y, self_mask) 125 | y = self.self_attention_dropout(y) 126 | x = x + y 127 | 128 | if enc_output is not None: 129 | y = self.enc_dec_attention_norm(x) 130 | y = self.enc_dec_attention(y, enc_output, enc_output, i_mask, 131 | cache) 132 | y = self.enc_dec_attention_dropout(y) 133 | x = x + y 134 | 135 | y = self.ffn_norm(x) 136 | y = self.ffn(y) 137 | y = self.ffn_dropout(y) 138 | x = x + y 139 | return x 140 | 141 | 142 | class Encoder(nn.Module): 143 | def __init__(self, hidden_size, filter_size, dropout_rate, n_layers): 144 | super(Encoder, self).__init__() 145 | 146 | encoders = [EncoderLayer(hidden_size, filter_size, dropout_rate) 147 | for _ in range(n_layers)] 148 | self.layers = nn.ModuleList(encoders) 149 | 150 | self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6) 151 | 152 | def forward(self, inputs, mask): 153 | encoder_output = inputs 154 | for enc_layer in self.layers: 155 | encoder_output = enc_layer(encoder_output, mask) 156 | return self.last_norm(encoder_output) 157 | 158 | 159 | class Decoder(nn.Module): 160 | def __init__(self, hidden_size, filter_size, dropout_rate, n_layers): 161 | super(Decoder, self).__init__() 162 | 163 | decoders = [DecoderLayer(hidden_size, filter_size, dropout_rate) 164 | for _ in range(n_layers)] 165 | self.layers = nn.ModuleList(decoders) 166 | 167 | 
self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6) 168 | 169 | def forward(self, targets, enc_output, i_mask, t_self_mask, cache): 170 | decoder_output = targets 171 | for i, dec_layer in enumerate(self.layers): 172 | layer_cache = None 173 | if cache is not None: 174 | if i not in cache: 175 | cache[i] = {} 176 | layer_cache = cache[i] 177 | decoder_output = dec_layer(decoder_output, enc_output, 178 | t_self_mask, i_mask, layer_cache) 179 | return self.last_norm(decoder_output) 180 | 181 | 182 | class FastTransformer(nn.Module): 183 | def __init__(self, i_vocab_size, t_vocab_size, 184 | n_layers=6, 185 | hidden_size=512, 186 | filter_size=2048, 187 | dropout_rate=0.1, 188 | share_target_embedding=True, 189 | has_inputs=True, 190 | src_pad_idx=None, 191 | trg_pad_idx=None): 192 | super(FastTransformer, self).__init__() 193 | 194 | self.hidden_size = hidden_size 195 | self.emb_scale = hidden_size ** 0.5 196 | self.has_inputs = has_inputs 197 | self.src_pad_idx = src_pad_idx 198 | self.trg_pad_idx = trg_pad_idx 199 | 200 | self.t_vocab_embedding = nn.Embedding(t_vocab_size, hidden_size) 201 | nn.init.normal_(self.t_vocab_embedding.weight, mean=0, 202 | std=hidden_size**-0.5) 203 | self.t_emb_dropout = nn.Dropout(dropout_rate) 204 | self.decoder = Decoder(hidden_size, filter_size, 205 | dropout_rate, n_layers) 206 | 207 | if has_inputs: 208 | if not share_target_embedding: 209 | self.i_vocab_embedding = nn.Embedding(i_vocab_size, 210 | hidden_size) 211 | nn.init.normal_(self.i_vocab_embedding.weight, mean=0, 212 | std=hidden_size**-0.5) 213 | else: 214 | self.i_vocab_embedding = self.t_vocab_embedding 215 | 216 | self.i_emb_dropout = nn.Dropout(dropout_rate) 217 | 218 | self.encoder = Encoder(hidden_size, filter_size, 219 | dropout_rate, n_layers) 220 | 221 | # For positional encoding 222 | num_timescales = self.hidden_size // 2 223 | max_timescale = 10000.0 224 | min_timescale = 1.0 225 | log_timescale_increment = ( 226 | math.log(float(max_timescale) / float(min_timescale)) / 227 | max(num_timescales - 1, 1)) 228 | inv_timescales = min_timescale * torch.exp( 229 | torch.arange(num_timescales, dtype=torch.float32) * 230 | -log_timescale_increment) 231 | self.register_buffer('inv_timescales', inv_timescales) 232 | 233 | def forward(self, inputs, targets): 234 | enc_output, i_mask = None, None 235 | if self.has_inputs: 236 | i_mask = utils.create_pad_mask(inputs, self.src_pad_idx) 237 | enc_output = self.encode(inputs, i_mask) 238 | 239 | t_mask = utils.create_pad_mask(targets, self.trg_pad_idx) 240 | target_size = targets.size()[1] 241 | t_self_mask = utils.create_trg_self_mask(target_size, 242 | device=targets.device) 243 | return self.decode(targets, enc_output, i_mask, t_self_mask, t_mask) 244 | 245 | def encode(self, inputs, i_mask): 246 | # Input embedding 247 | input_embedded = self.i_vocab_embedding(inputs) 248 | input_embedded.masked_fill_(i_mask.squeeze(1).unsqueeze(-1), 0) 249 | input_embedded *= self.emb_scale 250 | input_embedded += self.get_position_encoding(inputs) 251 | input_embedded = self.i_emb_dropout(input_embedded) 252 | 253 | i_mask = i_mask.size(2) - i_mask.sum(dim=2, dtype=torch.int32) 254 | return self.encoder(input_embedded, i_mask) 255 | 256 | def decode(self, targets, enc_output, i_mask, t_self_mask, t_mask, 257 | cache=None): 258 | # target embedding 259 | target_embedded = self.t_vocab_embedding(targets) 260 | target_embedded.masked_fill_(t_mask.squeeze(1).unsqueeze(-1), 0) 261 | 262 | # Shifting 263 | target_embedded = target_embedded[:, :-1] 264 | 
target_embedded = F.pad(target_embedded, (0, 0, 1, 0)) 265 | 266 | target_embedded *= self.emb_scale 267 | target_embedded += self.get_position_encoding(targets) 268 | target_embedded = self.t_emb_dropout(target_embedded) 269 | 270 | # decoder 271 | if i_mask is not None: 272 | i_mask = i_mask.size(2) - i_mask.sum(dim=2, dtype=torch.int32) 273 | t_self_mask = \ 274 | t_self_mask.size(2) - t_self_mask.sum(dim=2, dtype=torch.int32) 275 | decoder_output = self.decoder(target_embedded, enc_output, i_mask, 276 | t_self_mask, cache) 277 | # linear 278 | output = torch.matmul(decoder_output, 279 | self.t_vocab_embedding.weight.transpose(0, 1)) 280 | 281 | return output 282 | 283 | def get_position_encoding(self, x): 284 | max_length = x.size()[1] 285 | position = torch.arange(max_length, dtype=torch.float32, 286 | device=x.device) 287 | scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0) 288 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 289 | dim=1) 290 | signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2)) 291 | signal = signal.view(1, max_length, self.hidden_size) 292 | return signal 293 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils import utils 8 | 9 | # pylint: disable=arguments-differ 10 | 11 | 12 | def initialize_weight(x): 13 | nn.init.xavier_uniform_(x.weight) 14 | if x.bias is not None: 15 | nn.init.constant_(x.bias, 0) 16 | 17 | 18 | class FeedForwardNetwork(nn.Module): 19 | def __init__(self, hidden_size, filter_size, dropout_rate): 20 | super(FeedForwardNetwork, self).__init__() 21 | 22 | self.layer1 = nn.Linear(hidden_size, filter_size) 23 | self.relu = nn.ReLU() 24 | self.dropout = nn.Dropout(dropout_rate) 25 | self.layer2 = nn.Linear(filter_size, hidden_size) 26 | 27 | initialize_weight(self.layer1) 28 | initialize_weight(self.layer2) 29 | 30 | def forward(self, x): 31 | x = self.layer1(x) 32 | x = self.relu(x) 33 | x = self.dropout(x) 34 | x = self.layer2(x) 35 | return x 36 | 37 | 38 | class MultiHeadAttention(nn.Module): 39 | def __init__(self, hidden_size, dropout_rate, head_size=8): 40 | super(MultiHeadAttention, self).__init__() 41 | 42 | self.head_size = head_size 43 | 44 | self.att_size = att_size = hidden_size // head_size 45 | self.scale = att_size ** -0.5 46 | 47 | self.linear_q = nn.Linear(hidden_size, head_size * att_size, bias=False) 48 | self.linear_k = nn.Linear(hidden_size, head_size * att_size, bias=False) 49 | self.linear_v = nn.Linear(hidden_size, head_size * att_size, bias=False) 50 | initialize_weight(self.linear_q) 51 | initialize_weight(self.linear_k) 52 | initialize_weight(self.linear_v) 53 | 54 | self.att_dropout = nn.Dropout(dropout_rate) 55 | 56 | self.output_layer = nn.Linear(head_size * att_size, hidden_size, 57 | bias=False) 58 | initialize_weight(self.output_layer) 59 | 60 | def forward(self, q, k, v, mask, cache=None): 61 | orig_q_size = q.size() 62 | 63 | d_k = self.att_size 64 | d_v = self.att_size 65 | batch_size = q.size(0) 66 | 67 | # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i) 68 | q = self.linear_q(q).view(batch_size, -1, self.head_size, d_k) 69 | if cache is not None and 'encdec_k' in cache: 70 | k, v = cache['encdec_k'], cache['encdec_v'] 71 | else: 72 | k = self.linear_k(k).view(batch_size, -1, self.head_size, d_k) 73 | v = 
self.linear_v(v).view(batch_size, -1, self.head_size, d_v) 74 | 75 | if cache is not None: 76 | cache['encdec_k'], cache['encdec_v'] = k, v 77 | 78 | q = q.transpose(1, 2) # [b, h, q_len, d_k] 79 | v = v.transpose(1, 2) # [b, h, v_len, d_v] 80 | k = k.transpose(1, 2).transpose(2, 3) # [b, h, d_k, k_len] 81 | 82 | # Scaled Dot-Product Attention. 83 | # Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V 84 | q.mul_(self.scale) 85 | x = torch.matmul(q, k) # [b, h, q_len, k_len] 86 | x.masked_fill_(mask.unsqueeze(1), -1e9) 87 | x = torch.softmax(x, dim=3) 88 | x = self.att_dropout(x) 89 | x = x.matmul(v) # [b, h, q_len, attn] 90 | 91 | x = x.transpose(1, 2).contiguous() # [b, q_len, h, attn] 92 | x = x.view(batch_size, -1, self.head_size * d_v) 93 | 94 | x = self.output_layer(x) 95 | 96 | assert x.size() == orig_q_size 97 | return x 98 | 99 | 100 | class EncoderLayer(nn.Module): 101 | def __init__(self, hidden_size, filter_size, dropout_rate): 102 | super(EncoderLayer, self).__init__() 103 | 104 | self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) 105 | self.self_attention = MultiHeadAttention(hidden_size, dropout_rate) 106 | self.self_attention_dropout = nn.Dropout(dropout_rate) 107 | 108 | self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6) 109 | self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate) 110 | self.ffn_dropout = nn.Dropout(dropout_rate) 111 | 112 | def forward(self, x, mask): # pylint: disable=arguments-differ 113 | y = self.self_attention_norm(x) 114 | y = self.self_attention(y, y, y, mask) 115 | y = self.self_attention_dropout(y) 116 | x = x + y 117 | 118 | y = self.ffn_norm(x) 119 | y = self.ffn(y) 120 | y = self.ffn_dropout(y) 121 | x = x + y 122 | return x 123 | 124 | 125 | class DecoderLayer(nn.Module): 126 | def __init__(self, hidden_size, filter_size, dropout_rate): 127 | super(DecoderLayer, self).__init__() 128 | 129 | self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) 130 | self.self_attention = MultiHeadAttention(hidden_size, dropout_rate) 131 | self.self_attention_dropout = nn.Dropout(dropout_rate) 132 | 133 | self.enc_dec_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) 134 | self.enc_dec_attention = MultiHeadAttention(hidden_size, dropout_rate) 135 | self.enc_dec_attention_dropout = nn.Dropout(dropout_rate) 136 | 137 | self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6) 138 | self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate) 139 | self.ffn_dropout = nn.Dropout(dropout_rate) 140 | 141 | def forward(self, x, enc_output, self_mask, i_mask, cache): 142 | y = self.self_attention_norm(x) 143 | y = self.self_attention(y, y, y, self_mask) 144 | y = self.self_attention_dropout(y) 145 | x = x + y 146 | 147 | if enc_output is not None: 148 | y = self.enc_dec_attention_norm(x) 149 | y = self.enc_dec_attention(y, enc_output, enc_output, i_mask, 150 | cache) 151 | y = self.enc_dec_attention_dropout(y) 152 | x = x + y 153 | 154 | y = self.ffn_norm(x) 155 | y = self.ffn(y) 156 | y = self.ffn_dropout(y) 157 | x = x + y 158 | return x 159 | 160 | 161 | class Encoder(nn.Module): 162 | def __init__(self, hidden_size, filter_size, dropout_rate, n_layers): 163 | super(Encoder, self).__init__() 164 | 165 | encoders = [EncoderLayer(hidden_size, filter_size, dropout_rate) 166 | for _ in range(n_layers)] 167 | self.layers = nn.ModuleList(encoders) 168 | 169 | self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6) 170 | 171 | def forward(self, inputs, mask): 172 | encoder_output = inputs 173 | for enc_layer in self.layers: 174 | 
encoder_output = enc_layer(encoder_output, mask) 175 | return self.last_norm(encoder_output) 176 | 177 | 178 | class Decoder(nn.Module): 179 | def __init__(self, hidden_size, filter_size, dropout_rate, n_layers): 180 | super(Decoder, self).__init__() 181 | 182 | decoders = [DecoderLayer(hidden_size, filter_size, dropout_rate) 183 | for _ in range(n_layers)] 184 | self.layers = nn.ModuleList(decoders) 185 | 186 | self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6) 187 | 188 | def forward(self, targets, enc_output, i_mask, t_self_mask, cache): 189 | decoder_output = targets 190 | for i, dec_layer in enumerate(self.layers): 191 | layer_cache = None 192 | if cache is not None: 193 | if i not in cache: 194 | cache[i] = {} 195 | layer_cache = cache[i] 196 | decoder_output = dec_layer(decoder_output, enc_output, 197 | t_self_mask, i_mask, layer_cache) 198 | return self.last_norm(decoder_output) 199 | 200 | 201 | class Transformer(nn.Module): 202 | def __init__(self, i_vocab_size, t_vocab_size, 203 | n_layers=6, 204 | hidden_size=512, 205 | filter_size=2048, 206 | dropout_rate=0.1, 207 | share_target_embedding=True, 208 | has_inputs=True, 209 | src_pad_idx=None, 210 | trg_pad_idx=None): 211 | super(Transformer, self).__init__() 212 | 213 | self.hidden_size = hidden_size 214 | self.emb_scale = hidden_size ** 0.5 215 | self.has_inputs = has_inputs 216 | self.src_pad_idx = src_pad_idx 217 | self.trg_pad_idx = trg_pad_idx 218 | 219 | self.t_vocab_embedding = nn.Embedding(t_vocab_size, hidden_size) 220 | nn.init.normal_(self.t_vocab_embedding.weight, mean=0, 221 | std=hidden_size**-0.5) 222 | self.t_emb_dropout = nn.Dropout(dropout_rate) 223 | self.decoder = Decoder(hidden_size, filter_size, 224 | dropout_rate, n_layers) 225 | 226 | if has_inputs: 227 | if not share_target_embedding: 228 | self.i_vocab_embedding = nn.Embedding(i_vocab_size, 229 | hidden_size) 230 | nn.init.normal_(self.i_vocab_embedding.weight, mean=0, 231 | std=hidden_size**-0.5) 232 | else: 233 | self.i_vocab_embedding = self.t_vocab_embedding 234 | 235 | self.i_emb_dropout = nn.Dropout(dropout_rate) 236 | 237 | self.encoder = Encoder(hidden_size, filter_size, 238 | dropout_rate, n_layers) 239 | 240 | # For positional encoding 241 | num_timescales = self.hidden_size // 2 242 | max_timescale = 10000.0 243 | min_timescale = 1.0 244 | log_timescale_increment = ( 245 | math.log(float(max_timescale) / float(min_timescale)) / 246 | max(num_timescales - 1, 1)) 247 | inv_timescales = min_timescale * torch.exp( 248 | torch.arange(num_timescales, dtype=torch.float32) * 249 | -log_timescale_increment) 250 | self.register_buffer('inv_timescales', inv_timescales) 251 | 252 | def forward(self, inputs, targets): 253 | enc_output, i_mask = None, None 254 | if self.has_inputs: 255 | i_mask = utils.create_pad_mask(inputs, self.src_pad_idx) 256 | enc_output = self.encode(inputs, i_mask) 257 | 258 | t_mask = utils.create_pad_mask(targets, self.trg_pad_idx) 259 | target_size = targets.size()[1] 260 | t_self_mask = utils.create_trg_self_mask(target_size, 261 | device=targets.device) 262 | return self.decode(targets, enc_output, i_mask, t_self_mask, t_mask) 263 | 264 | def encode(self, inputs, i_mask): 265 | # Input embedding 266 | input_embedded = self.i_vocab_embedding(inputs) 267 | input_embedded.masked_fill_(i_mask.squeeze(1).unsqueeze(-1), 0) 268 | input_embedded *= self.emb_scale 269 | input_embedded += self.get_position_encoding(inputs) 270 | input_embedded = self.i_emb_dropout(input_embedded) 271 | 272 | return self.encoder(input_embedded, i_mask) 
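# Note: decode() below shifts the target embeddings one position to the right
# (the slice + F.pad pair) so that position i only conditions on earlier target
# tokens, and it reuses t_vocab_embedding.weight as the output projection
# (weight tying).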
273 | 274 | def decode(self, targets, enc_output, i_mask, t_self_mask, t_mask, 275 | cache=None): 276 | # target embedding 277 | target_embedded = self.t_vocab_embedding(targets) 278 | target_embedded.masked_fill_(t_mask.squeeze(1).unsqueeze(-1), 0) 279 | 280 | # Shifting 281 | target_embedded = target_embedded[:, :-1] 282 | target_embedded = F.pad(target_embedded, (0, 0, 1, 0)) 283 | 284 | target_embedded *= self.emb_scale 285 | target_embedded += self.get_position_encoding(targets) 286 | target_embedded = self.t_emb_dropout(target_embedded) 287 | 288 | # decoder 289 | decoder_output = self.decoder(target_embedded, enc_output, i_mask, 290 | t_self_mask, cache) 291 | # linear 292 | output = torch.matmul(decoder_output, 293 | self.t_vocab_embedding.weight.transpose(0, 1)) 294 | 295 | return output 296 | 297 | def get_position_encoding(self, x): 298 | max_length = x.size()[1] 299 | position = torch.arange(max_length, dtype=torch.float32, 300 | device=x.device) 301 | scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0) 302 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 303 | dim=1) 304 | signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2)) 305 | signal = signal.view(1, max_length, self.hidden_size) 306 | return signal 307 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | torchtext==0.3.1 3 | tqdm 4 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | from tensorboardX import SummaryWriter 7 | from tqdm import tqdm 8 | 9 | from dataset import problem 10 | from utils.optimizer import LRScheduler 11 | from utils import utils 12 | 13 | 14 | def summarize_train(writer, global_step, last_time, model, opt, 15 | inputs, targets, optimizer, loss, pred, ans): 16 | if opt.summary_grad: 17 | for name, param in model.named_parameters(): 18 | if not param.requires_grad: 19 | continue 20 | 21 | norm = torch.norm(param.grad.data.view(-1)) 22 | writer.add_scalar('gradient_norm/' + name, norm, 23 | global_step) 24 | 25 | writer.add_scalar('input_stats/batch_size', 26 | targets.size(0), global_step) 27 | 28 | if inputs is not None: 29 | writer.add_scalar('input_stats/input_length', 30 | inputs.size(1), global_step) 31 | i_nonpad = (inputs != opt.src_pad_idx).view(-1).type(torch.float32) 32 | writer.add_scalar('input_stats/inputs_nonpadding_frac', 33 | i_nonpad.mean(), global_step) 34 | 35 | writer.add_scalar('input_stats/target_length', 36 | targets.size(1), global_step) 37 | t_nonpad = (targets != opt.trg_pad_idx).view(-1).type(torch.float32) 38 | writer.add_scalar('input_stats/target_nonpadding_frac', 39 | t_nonpad.mean(), global_step) 40 | 41 | writer.add_scalar('optimizer/learning_rate', 42 | optimizer.learning_rate(), global_step) 43 | 44 | writer.add_scalar('loss', loss.item(), global_step) 45 | 46 | acc = utils.get_accuracy(pred, ans, opt.trg_pad_idx) 47 | writer.add_scalar('training/accuracy', 48 | acc, global_step) 49 | 50 | steps_per_sec = 100.0 / (time.time() - last_time) 51 | writer.add_scalar('global_step/sec', steps_per_sec, 52 | global_step) 53 | 54 | 55 | def train(train_data, model, opt, global_step, optimizer, t_vocab_size, 56 | label_smoothing, writer): 57 | model.train() 58 | last_time = time.time() 59 
| pbar = tqdm(total=len(train_data.dataset), ascii=True) 60 | for batch in train_data: 61 | inputs = None 62 | if opt.has_inputs: 63 | inputs = batch.src 64 | 65 | targets = batch.trg 66 | pred = model(inputs, targets) 67 | 68 | pred = pred.view(-1, pred.size(-1)) 69 | ans = targets.view(-1) 70 | 71 | loss = utils.get_loss(pred, ans, t_vocab_size, 72 | label_smoothing, opt.trg_pad_idx) 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | 77 | if global_step % 100 == 0: 78 | summarize_train(writer, global_step, last_time, model, opt, 79 | inputs, targets, optimizer, loss, pred, ans) 80 | last_time = time.time() 81 | 82 | pbar.set_description('[Loss: {:.4f}]'.format(loss.item())) 83 | 84 | global_step += 1 85 | pbar.update(targets.size(0)) 86 | 87 | pbar.close() 88 | train_data.reload_examples() 89 | return global_step 90 | 91 | 92 | def validation(validation_data, model, global_step, t_vocab_size, val_writer, 93 | opt): 94 | model.eval() 95 | total_loss = 0.0 96 | total_cnt = 0 97 | for batch in validation_data: 98 | inputs = None 99 | if opt.has_inputs: 100 | inputs = batch.src 101 | targets = batch.trg 102 | 103 | with torch.no_grad(): 104 | pred = model(inputs, targets) 105 | 106 | pred = pred.view(-1, pred.size(-1)) 107 | ans = targets.view(-1) 108 | loss = utils.get_loss(pred, ans, t_vocab_size, 0, 109 | opt.trg_pad_idx) 110 | total_loss += loss.item() * len(batch) 111 | total_cnt += len(batch) 112 | 113 | val_loss = total_loss / total_cnt 114 | print("Validation Loss", val_loss) 115 | val_writer.add_scalar('loss', val_loss, global_step) 116 | return val_loss 117 | 118 | 119 | def main(): 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--problem', required=True) 122 | parser.add_argument('--train_step', type=int, default=200) 123 | parser.add_argument('--batch_size', type=int, default=4096) 124 | parser.add_argument('--max_length', type=int, default=100) 125 | parser.add_argument('--n_layers', type=int, default=6) 126 | parser.add_argument('--hidden_size', type=int, default=512) 127 | parser.add_argument('--filter_size', type=int, default=2048) 128 | parser.add_argument('--warmup', type=int, default=16000) 129 | parser.add_argument('--val_every', type=int, default=5) 130 | parser.add_argument('--dropout', type=float, default=0.1) 131 | parser.add_argument('--label_smoothing', type=float, default=0.1) 132 | parser.add_argument('--model', type=str, default='transformer') 133 | parser.add_argument('--output_dir', type=str, default='./output') 134 | parser.add_argument('--data_dir', type=str, default='./data') 135 | parser.add_argument('--no_cuda', action='store_true') 136 | parser.add_argument('--parallel', action='store_true') 137 | parser.add_argument('--summary_grad', action='store_true') 138 | opt = parser.parse_args() 139 | 140 | device = torch.device('cpu' if opt.no_cuda else 'cuda') 141 | 142 | if not os.path.exists(opt.output_dir + '/last/models'): 143 | os.makedirs(opt.output_dir + '/last/models') 144 | if not os.path.exists(opt.data_dir): 145 | os.makedirs(opt.data_dir) 146 | 147 | train_data, validation_data, i_vocab_size, t_vocab_size, opt = \ 148 | problem.prepare(opt.problem, opt.data_dir, opt.max_length, 149 | opt.batch_size, device, opt) 150 | if i_vocab_size is not None: 151 | print("# of vocabs (input):", i_vocab_size) 152 | print("# of vocabs (target):", t_vocab_size) 153 | 154 | if opt.model == 'transformer': 155 | from model.transformer import Transformer 156 | model_fn = Transformer 157 | elif opt.model == 'fast_transformer': 
158 | from model.fast_transformer import FastTransformer 159 | model_fn = FastTransformer 160 | 161 | if os.path.exists(opt.output_dir + '/last/models/last_model.pt'): 162 | print("Load a checkpoint...") 163 | last_model_path = opt.output_dir + '/last/models' 164 | model, global_step = utils.load_checkpoint(last_model_path, device, 165 | is_eval=False) 166 | else: 167 | model = model_fn(i_vocab_size, t_vocab_size, 168 | n_layers=opt.n_layers, 169 | hidden_size=opt.hidden_size, 170 | filter_size=opt.filter_size, 171 | dropout_rate=opt.dropout, 172 | share_target_embedding=opt.share_target_embedding, 173 | has_inputs=opt.has_inputs, 174 | src_pad_idx=opt.src_pad_idx, 175 | trg_pad_idx=opt.trg_pad_idx) 176 | model = model.to(device=device) 177 | global_step = 0 178 | 179 | if opt.parallel: 180 | print("Use", torch.cuda.device_count(), "GPUs") 181 | model = torch.nn.DataParallel(model) 182 | 183 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 184 | print("# of parameters: {}".format(num_params)) 185 | 186 | optimizer = LRScheduler( 187 | filter(lambda x: x.requires_grad, model.parameters()), 188 | opt.hidden_size, opt.warmup, step=global_step) 189 | 190 | writer = SummaryWriter(opt.output_dir + '/last') 191 | val_writer = SummaryWriter(opt.output_dir + '/last/val') 192 | best_val_loss = float('inf') 193 | 194 | for t_step in range(opt.train_step): 195 | print("Epoch", t_step) 196 | start_epoch_time = time.time() 197 | global_step = train(train_data, model, opt, global_step, 198 | optimizer, t_vocab_size, opt.label_smoothing, 199 | writer) 200 | print("Epoch Time: {:.2f} sec".format(time.time() - start_epoch_time)) 201 | 202 | if t_step % opt.val_every != 0: 203 | continue 204 | 205 | val_loss = validation(validation_data, model, global_step, 206 | t_vocab_size, val_writer, opt) 207 | utils.save_checkpoint(model, opt.output_dir + '/last/models', 208 | global_step, val_loss < best_val_loss) 209 | best_val_loss = min(val_loss, best_val_loss) 210 | 211 | 212 | if __name__ == '__main__': 213 | main() 214 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | class LRScheduler: 5 | def __init__(self, parameters, hidden_size, warmup, step=0): 6 | self.constant = 2.0 * (hidden_size ** -0.5) 7 | self.cur_step = step 8 | self.warmup = warmup 9 | self.optimizer = optim.Adam(parameters, lr=self.learning_rate(), 10 | betas=(0.9, 0.997), eps=1e-09) 11 | 12 | def step(self): 13 | self.cur_step += 1 14 | rate = self.learning_rate() 15 | for p in self.optimizer.param_groups: 16 | p['lr'] = rate 17 | self.optimizer.step() 18 | 19 | def zero_grad(self): 20 | self.optimizer.zero_grad() 21 | 22 | def learning_rate(self): 23 | lr = self.constant 24 | lr *= min(1.0, self.cur_step / self.warmup) 25 | lr *= max(self.cur_step, self.warmup) ** -0.5 26 | return lr 27 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def get_loss(pred, ans, vocab_size, label_smoothing, pad): 9 | # took this "normalizing" from tensor2tensor. We subtract it for 10 | # readability. This makes no difference on learning. 
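# With label smoothing eps, the target distribution puts (1 - eps) on the gold
# token and eps / (vocab_size - 1) on every other token; "normalizing" is the
# entropy of that distribution, i.e. the lowest cross-entropy any prediction
# can achieve, so subtracting it lets a perfect model report a loss near 0.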
11 | confidence = 1.0 - label_smoothing 12 | low_confidence = (1.0 - confidence) / float(vocab_size - 1) 13 | normalizing = -( 14 | confidence * math.log(confidence) + float(vocab_size - 1) * 15 | low_confidence * math.log(low_confidence + 1e-20)) 16 | 17 | one_hot = torch.zeros_like(pred).scatter_(1, ans.unsqueeze(1), 1) 18 | one_hot = one_hot * confidence + (1 - one_hot) * low_confidence 19 | log_prob = F.log_softmax(pred, dim=1) 20 | 21 | xent = -(one_hot * log_prob).sum(dim=1) 22 | xent = xent.masked_select(ans != pad) 23 | loss = (xent - normalizing).mean() 24 | return loss 25 | 26 | 27 | def get_accuracy(pred, ans, pad): 28 | pred = pred.max(1)[1] 29 | n_correct = pred.eq(ans) 30 | n_correct = n_correct.masked_select(ans != pad) 31 | return n_correct.sum().item() / n_correct.size(0) 32 | 33 | 34 | def save_checkpoint(model, filepath, global_step, is_best): 35 | model_save_path = filepath + '/last_model.pt' 36 | torch.save(model, model_save_path) 37 | torch.save(global_step, filepath + '/global_step.pt') 38 | if is_best: 39 | best_save_path = filepath + '/best_model.pt' 40 | shutil.copyfile(model_save_path, best_save_path) 41 | 42 | 43 | def load_checkpoint(model_path, device, is_eval=True): 44 | if is_eval: 45 | model = torch.load(model_path + '/best_model.pt') 46 | model.eval() 47 | return model.to(device=device) 48 | 49 | model = torch.load(model_path + '/last_model.pt') 50 | global_step = torch.load(model_path + '/global_step.pt') 51 | return model.to(device=device), global_step 52 | 53 | 54 | def create_pad_mask(t, pad): 55 | mask = (t == pad).unsqueeze(-2) 56 | return mask 57 | 58 | 59 | def create_trg_self_mask(target_len, device=None): 60 | # Prevent leftward information flow in self-attention. 61 | ones = torch.ones(target_len, target_len, dtype=torch.uint8, 62 | device=device) 63 | t_self_mask = torch.triu(ones, diagonal=1).unsqueeze(0) 64 | 65 | return t_self_mask 66 | --------------------------------------------------------------------------------
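As a quick sanity check of the plain `Transformer` module above, the sketch below builds a tiny model and runs one forward pass. It is not part of the repository: the vocabulary size, layer sizes, pad index, and toy token ids are made up for illustration, and it assumes it is executed from the repository root (so `model.transformer` is importable) under the PyTorch version noted in the README.

```
import torch

from model.transformer import Transformer

# Toy configuration -- all numbers here are illustrative only.
pad_idx = 0
model = Transformer(i_vocab_size=100, t_vocab_size=100,
                    n_layers=2, hidden_size=64, filter_size=128,
                    dropout_rate=0.1,
                    src_pad_idx=pad_idx, trg_pad_idx=pad_idx)
model.eval()  # disable dropout for a deterministic smoke test

# Two source/target pairs of length 5; token id 0 acts as padding.
inputs = torch.tensor([[4, 7, 9, 0, 0], [5, 6, 2, 3, 0]])
targets = torch.tensor([[8, 3, 0, 0, 0], [9, 9, 4, 1, 0]])

with torch.no_grad():
    logits = model(inputs, targets)  # targets are shifted right internally

print(logits.shape)  # torch.Size([2, 5, 100]) -- [batch, target_len, t_vocab_size]
```

During training, `train.py` flattens these logits to `[batch * target_len, t_vocab_size]` and hands them to `utils.get_loss` together with the flattened targets.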