├── .gitignore
├── LICENSE
├── README.md
├── beam.py
├── configs
│   └── example_config.json
├── data
│   └── example
│       └── raw
│           ├── src-test.txt
│           ├── src-train.txt
│           ├── src-val.txt
│           ├── tgt-train.txt
│           └── tgt-val.txt
├── datasets.py
├── dictionaries.py
├── embeddings.py
├── evaluate.py
├── evaluator.py
├── losses.py
├── metrics.py
├── models.py
├── optimizers.py
├── predict.py
├── predictors.py
├── prepare_datasets.py
├── train.py
├── trainer.py
└── utils
    ├── log.py
    ├── pad.py
    └── pipe.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
checkpoints/
logs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Yongrae Jo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Transformer-pytorch
A PyTorch implementation of the Transformer from "Attention Is All You Need" (https://arxiv.org/abs/1706.03762)

This repo focuses on a clean, readable, and modular implementation of the paper.

## Requirements
- Python 3.6+
- [PyTorch 0.4.1+](http://pytorch.org/)
- [NumPy](http://www.numpy.org/)
- [NLTK](https://www.nltk.org/)
- [tqdm](https://github.com/tqdm/tqdm)

## Usage

### Prepare datasets
This repo comes with example data in the `data/` directory. To begin, prepare datasets from the given data as follows:
```
$ python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed
```

The example data comes from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py).
The data consists of parallel source (src) and target (tgt) data for training and validation.
A data file contains one sentence per line, with tokens separated by spaces.
Below are the provided example data files.

- `src-train.txt`
- `tgt-train.txt`
- `src-val.txt`
- `tgt-val.txt`

### Train model
To train a model, point the train script at the processed data and tell it where to save the config, checkpoint, and log files:

```
$ python train.py --data_dir=data/example/processed --save_config=checkpoints/example_config.json --save_checkpoint=checkpoints/example_model.pth --save_log=logs/example.log
```

This saves the model config and checkpoints to the given files.
You can adjust the model's hyperparameters with command line arguments.
For example, add `--epochs=300` to set the number of epochs to 300.

### Translate
To translate a sentence from the source language to the target language:
```
$ python predict.py --source="There is an imbalance here ." --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

Candidate 0 : Hier fehlt das Gleichgewicht .
Candidate 1 : Hier fehlt das das Gleichgewicht .
Candidate 2 : Hier fehlt das das das Gleichgewicht .
```

This prints translation candidates for the given source sentence.
You can adjust the number of candidates with the `--num_candidates` command line argument.

### Evaluate
To calculate the BLEU score of a trained model:
```
$ python evaluate.py --save_result=logs/example_eval.txt --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

BLEU score : 0.0007947
```

## File description
- `models.py` includes the Transformer's encoder, decoder, and multi-head attention.
- `embeddings.py` contains the positional encoding.
- `losses.py` contains the label smoothing loss.
- `optimizers.py` contains the Noam optimizer.
- `metrics.py` contains the accuracy metric.
- `beam.py` contains beam search.
- `datasets.py` has code for loading and processing data.
- `trainer.py` has code for training the model.
- `prepare_datasets.py` processes data.
- `train.py` trains the model.
- `predict.py` translates a given source sentence with a trained model.
- `evaluate.py` calculates the BLEU score of a trained model.

## Reference
- [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py)

## Author
[@dreamgonfly](https://github.com/dreamgonfly)

--------------------------------------------------------------------------------
/beam.py:
--------------------------------------------------------------------------------
import torch


class Beam:

    def __init__(self, beam_size=8, min_length=0, n_top=1, ranker=None,
                 start_token_id=2, end_token_id=3):
        self.beam_size = beam_size
        self.min_length = min_length
        self.ranker = ranker

        self.end_token_id = end_token_id
        self.top_sentence_ended = False

        self.prev_ks = []
        self.next_ys = [torch.LongTensor(beam_size).fill_(start_token_id)]

        self.current_scores = torch.FloatTensor(beam_size).zero_()
        self.all_scores = []

        # The attentions (matrix) for each time step.
        self.all_attentions = []

        # Time and k pair for finished.
        self.finished = []
        self.n_top = n_top

    def advance(self, next_log_probs, current_attention):
        # next_log_probs: (beam_size, vocabulary_size)
        # current_attention: (target_seq_len=1, beam_size, source_seq_len)

        vocabulary_size = next_log_probs.size(1)

        current_length = len(self.next_ys)
        if current_length < self.min_length:
            # Forbid EOS until the minimum length is reached.
            for beam_index in range(len(next_log_probs)):
                next_log_probs[beam_index][self.end_token_id] = -1e10

        if len(self.prev_ks) > 0:
            beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(next_log_probs)
            # Don't let EOS have children.
            last_y = self.next_ys[-1]
            for beam_index in range(last_y.size(0)):
                if last_y[beam_index] == self.end_token_id:
                    beam_scores[beam_index] = -1e10  # -1e20 raises an error when executing
        else:
            beam_scores = next_log_probs[0]
        flat_beam_scores = beam_scores.view(-1)
        top_scores, top_score_ids = flat_beam_scores.topk(k=self.beam_size, dim=0, largest=True, sorted=True)

        self.current_scores = top_scores
        self.all_scores.append(self.current_scores)

        prev_k = top_score_ids // vocabulary_size  # (beam_size, ) beam each top score came from
        next_y = top_score_ids - prev_k * vocabulary_size  # (beam_size, ) token within that beam

        self.prev_ks.append(prev_k)
        self.next_ys.append(next_y)
        # For RNN, dim=1; for transformer, dim=0.
        prev_attention = current_attention.index_select(dim=0, index=prev_k)  # (target_seq_len=1, beam_size, source_seq_len)
        self.all_attentions.append(prev_attention)

        for beam_index, last_token_id in enumerate(next_y):
            if last_token_id == self.end_token_id:
                # skip scoring
                self.finished.append((self.current_scores[beam_index], len(self.next_ys) - 1, beam_index))

        if next_y[0] == self.end_token_id:
            self.top_sentence_ended = True

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.next_ys[-1]

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
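        # prev_ks[t][i] is the beam index at step t-1 that produced beam
        # entry i at step t; get_hypothesis follows these backpointers in
        # reverse to reconstruct a complete candidate sequence.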
        return self.prev_ks[-1]

    def done(self):
        return self.top_sentence_ended and len(self.finished) >= self.n_top

    def get_hypothesis(self, timestep, k):
        hypothesis, attentions = [], []
        for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
            hypothesis.append(self.next_ys[j + 1][k])
            # For RNN, [:, k, :]; for transformer, [k, :, :]
            attentions.append(self.all_attentions[j][k, :, :])
            k = self.prev_ks[j][k]
        attentions_tensor = torch.stack(attentions[::-1]).squeeze(1)  # (timestep, source_seq_len)
        return hypothesis[::-1], attentions_tensor

    def sort_finished(self, minimum=None):
        if minimum is not None:
            i = 0
            # Add from the beam until we have minimum outputs.
            while len(self.finished) < minimum:
                s = self.current_scores[i]
                self.finished.append((s, len(self.next_ys) - 1, i))
                i += 1

        self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True)
        scores = [sc for sc, _, _ in self.finished]
        ks = [(t, k) for _, t, k in self.finished]
        return scores, ks

--------------------------------------------------------------------------------
/configs/example_config.json:
--------------------------------------------------------------------------------
{
  "dataset_limit": null,
  "print_every": 1,
  "save_every": 1,

  "vocabulary_size": null,
  "share_dictionary": false,
  "positional_encoding": true,

  "d_model": 128,
  "layers_count": 1,
  "heads_count": 2,
  "d_ff": 128,
  "dropout_prob": 0.1,

  "label_smoothing": 0.1,
  "optimizer": "Noam",
  "lr": 0.001,
  "clip_grads": true,

  "batch_size": 10,
  "epochs": 10
}

--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
from os.path import dirname, abspath, join, exists
from os import makedirs
from dictionaries import START_TOKEN, END_TOKEN

UNK_INDEX = 1

BASE_DIR = dirname(abspath(__file__))


class TranslationDatasetOnTheFly:

    def __init__(self, phase, limit=None):
        assert phase in ('train', 'val'), "Dataset phase must be either 'train' or 'val'"

        self.limit = limit

        if phase == 'train':
            source_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'src-train.txt')
            target_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'tgt-train.txt')
        elif phase == 'val':
            source_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'src-val.txt')
            target_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'tgt-val.txt')
        else:
            raise NotImplementedError()

        with open(source_filepath) as source_file:
            self.source_data = source_file.readlines()

        with open(target_filepath) as target_file:
            self.target_data = target_file.readlines()

    def __getitem__(self, item):
        if self.limit is not None and item >= self.limit:
            raise IndexError()

        source = self.source_data[item].strip()
        target = self.target_data[item].strip()
        return source, target

    def __len__(self):
        if self.limit is None:
            return len(self.source_data)
        else:
            return self.limit


class TranslationDataset:

    def __init__(self, data_dir, phase, limit=None):
        assert phase in ('train', 'val'), "Dataset phase must be either 'train' or 'val'"

        self.limit = limit

        self.data = []
        with open(join(data_dir, f'raw-{phase}.txt')) as file:
            for line in file:
                source, target = line.strip().split('\t')
                self.data.append((source, target))

    def __getitem__(self, item):
        if self.limit is not None and item >= self.limit:
            raise IndexError()

        return self.data[item]

    def __len__(self):
        if self.limit is None:
            return len(self.data)
        else:
            return self.limit

    @staticmethod
    def prepare(train_source, train_target, val_source, val_target, save_data_dir):

        if not exists(save_data_dir):
            makedirs(save_data_dir)

        for phase in ('train', 'val'):

            if phase == 'train':
                source_filepath = train_source
                target_filepath = train_target
            else:
                source_filepath = val_source
                target_filepath = val_target

            with open(source_filepath) as source_file:
                source_data = source_file.readlines()

            with open(target_filepath) as target_file:
                target_data = target_file.readlines()

            with open(join(save_data_dir, f'raw-{phase}.txt'), 'w') as file:
                for source_line, target_line in zip(source_data, target_data):
                    source_line = source_line.strip()
                    target_line = target_line.strip()
                    line = f'{source_line}\t{target_line}\n'
                    file.write(line)


class TokenizedTranslationDatasetOnTheFly:

    def __init__(self, phase, limit=None):

        self.raw_dataset = TranslationDatasetOnTheFly(phase, limit)

    def __getitem__(self, item):
        raw_source, raw_target = self.raw_dataset[item]
        tokenized_source = raw_source.split()
        tokenized_target = raw_target.split()
        return tokenized_source, tokenized_target

    def __len__(self):
        return len(self.raw_dataset)


class TokenizedTranslationDataset:

    def __init__(self, data_dir, phase, limit=None):

        self.raw_dataset = TranslationDataset(data_dir, phase, limit)

    def __getitem__(self, item):
        raw_source, raw_target = self.raw_dataset[item]
        tokenized_source = raw_source.split()
        tokenized_target = raw_target.split()
        return tokenized_source, tokenized_target

    def __len__(self):
        return len(self.raw_dataset)


class InputTargetTranslationDatasetOnTheFly:

    def __init__(self, phase, limit=None):
        self.tokenized_dataset = TokenizedTranslationDatasetOnTheFly(phase, limit)

    def __getitem__(self, item):
        tokenized_source, tokenized_target = self.tokenized_dataset[item]
        full_target = [START_TOKEN] + tokenized_target + [END_TOKEN]
        inputs = full_target[:-1]
        targets = full_target[1:]
        return tokenized_source, inputs, targets

    def __len__(self):
        return len(self.tokenized_dataset)


class InputTargetTranslationDataset:

    def __init__(self, data_dir, phase, limit=None):
        self.tokenized_dataset = TokenizedTranslationDataset(data_dir, phase, limit)

    def __getitem__(self, item):
        tokenized_source, tokenized_target = self.tokenized_dataset[item]
        full_target = [START_TOKEN] + tokenized_target + [END_TOKEN]
        inputs = full_target[:-1]
        targets = full_target[1:]
        return tokenized_source, inputs, targets

    def __len__(self):
        return len(self.tokenized_dataset)


class IndexedInputTargetTranslationDatasetOnTheFly:

    def __init__(self, phase, source_dictionary, target_dictionary, limit=None):

        self.input_target_dataset = InputTargetTranslationDatasetOnTheFly(phase, limit)
        self.source_dictionary = source_dictionary
        self.target_dictionary = target_dictionary

    def __getitem__(self, item):
        source, inputs, targets = self.input_target_dataset[item]
        indexed_source = self.source_dictionary.index_sentence(source)
        indexed_inputs = self.target_dictionary.index_sentence(inputs)
        indexed_targets = self.target_dictionary.index_sentence(targets)

        return indexed_source, indexed_inputs, indexed_targets

    def __len__(self):
        return len(self.input_target_dataset)

    @staticmethod
    def preprocess(source_dictionary):

        def preprocess_function(source):
            source_tokens = source.strip().split()
            indexed_source = source_dictionary.index_sentence(source_tokens)
            return indexed_source

        return preprocess_function


class IndexedInputTargetTranslationDataset:

    def __init__(self, data_dir, phase, vocabulary_size=None, limit=None):

        self.data = []

        unknownify = lambda index: index if index < vocabulary_size else UNK_INDEX
        with open(join(data_dir, f'indexed-{phase}.txt')) as file:
            for line in file:
                sources, inputs, targets = line.strip().split('\t')
                if vocabulary_size is not None:
                    indexed_sources = [unknownify(int(index)) for index in sources.strip().split(' ')]
                    indexed_inputs = [unknownify(int(index)) for index in inputs.strip().split(' ')]
                    indexed_targets = [unknownify(int(index)) for index in targets.strip().split(' ')]
                else:
                    indexed_sources = [int(index) for index in sources.strip().split(' ')]
                    indexed_inputs = [int(index) for index in inputs.strip().split(' ')]
                    indexed_targets = [int(index) for index in targets.strip().split(' ')]
                self.data.append((indexed_sources, indexed_inputs, indexed_targets))
                if limit is not None and len(self.data) >= limit:
                    break

        self.vocabulary_size = vocabulary_size
        self.limit = limit

    def __getitem__(self, item):
        if self.limit is not None and item >= self.limit:
            raise IndexError()

        indexed_sources, indexed_inputs, indexed_targets = self.data[item]
        return indexed_sources, indexed_inputs, indexed_targets

    def __len__(self):
        if self.limit is None:
            return len(self.data)
        else:
            return self.limit

    @staticmethod
    def preprocess(source_dictionary):

        def preprocess_function(source):
            source_tokens = source.strip().split()
            indexed_source = source_dictionary.index_sentence(source_tokens)
            return indexed_source

        return preprocess_function

    @staticmethod
    def prepare(data_dir, source_dictionary, target_dictionary):

        join_indexes = lambda indexes: ' '.join(str(index) for index in indexes)
        for phase in ('train', 'val'):
            input_target_dataset = InputTargetTranslationDataset(data_dir, phase)

            with open(join(data_dir, f'indexed-{phase}.txt'), 'w') as file:
                for sources, inputs, targets in input_target_dataset:
                    indexed_sources = join_indexes(source_dictionary.index_sentence(sources))
                    indexed_inputs = join_indexes(target_dictionary.index_sentence(inputs))
                    indexed_targets = join_indexes(target_dictionary.index_sentence(targets))
                    file.write(f'{indexed_sources}\t{indexed_inputs}\t{indexed_targets}\n')

--------------------------------------------------------------------------------
/dictionaries.py:
--------------------------------------------------------------------------------
from collections import Counter
from os.path import dirname, abspath, join, exists
from os import makedirs

BASE_DIR = dirname(abspath(__file__))

# Special token markers. Any four distinct strings work, but their order below
# fixes the indices 0-3 (pad=0, unk=1, start=2, end=3) that beam.py's defaults
# and the pad utilities rely on.
PAD_TOKEN = '<Pad>'
UNK_TOKEN = '<Unk>'
START_TOKEN = '<StartSent>'
END_TOKEN = '<EndSent>'


class IndexDictionary:

    def __init__(self, iterable=None, mode='shared', vocabulary_size=None):

        self.special_tokens = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN]

        # On-the-fly mode
        if iterable is not None:

            self.vocab_tokens, self.token_counts = self._build_vocabulary(iterable, vocabulary_size)
            self.token_index_dict = {token: index for index, token in enumerate(self.vocab_tokens)}
            self.vocabulary_size = len(self.vocab_tokens)

        self.mode = mode

    def token_to_index(self, token):
        try:
            return self.token_index_dict[token]
        except KeyError:
            return self.token_index_dict[UNK_TOKEN]

    def index_to_token(self, index):
        if index >= self.vocabulary_size:
            return UNK_TOKEN
        else:
            return self.vocab_tokens[index]

    def index_sentence(self, sentence):
        return [self.token_to_index(token) for token in sentence]

    def tokenify_indexes(self, token_indexes):
        return [self.index_to_token(token_index) for token_index in token_indexes]

    def _build_vocabulary(self, iterable, vocabulary_size):

        counter = Counter()
        for token in iterable:
            counter[token] += 1

        if vocabulary_size is not None:
            most_commons = counter.most_common(vocabulary_size - len(self.special_tokens))
            frequent_tokens = [token for token, count in most_commons]
            vocab_tokens = self.special_tokens + frequent_tokens
            token_counts = [0] * len(self.special_tokens) + [count for token, count in most_commons]
        else:
            all_tokens = [token for token, count in counter.items()]
            vocab_tokens = self.special_tokens + all_tokens
            token_counts = [0] * len(self.special_tokens) + [count for token, count in counter.items()]

        return vocab_tokens, token_counts

    def save(self, data_dir):

        vocabulary_filepath = join(data_dir, f'vocabulary-{self.mode}.txt')
        with open(vocabulary_filepath, 'w') as file:
            for vocab_index, (vocab_token, count) in enumerate(zip(self.vocab_tokens, self.token_counts)):
                file.write(str(vocab_index) + '\t' + vocab_token + '\t' + str(count) + '\n')

    @classmethod
    def load(cls, data_dir, mode='shared', vocabulary_size=None):
        vocabulary_filepath = join(data_dir, f'vocabulary-{mode}.txt')

        vocab_tokens = {}
        token_counts = []
        with open(vocabulary_filepath) as file:
            for line in file:
                vocab_index, vocab_token, count = line.strip().split('\t')
                vocab_index = int(vocab_index)
                vocab_tokens[vocab_index] = vocab_token
                token_counts.append(int(count))

        if vocabulary_size is not None:
            vocab_tokens = {k: v for k, v in vocab_tokens.items() if k < vocabulary_size}
            token_counts = token_counts[:vocabulary_size]

        instance = cls(mode=mode)
        instance.vocab_tokens = vocab_tokens
        instance.token_counts = token_counts
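        # vocab_tokens maps index -> token here, so swapping each pair while
        # iterating .items() rebuilds the token -> index lookup.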
        instance.token_index_dict = {token: index for index, token in vocab_tokens.items()}
        instance.vocabulary_size = len(vocab_tokens)

        return instance

--------------------------------------------------------------------------------
/embeddings.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import math


class PositionalEncoding(nn.Module):
    """
    Implements the sinusoidal positional encoding for
    non-recurrent neural networks.

    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        num_embeddings (int): vocabulary size
        embedding_dim (int): size of the token embedding
        dim (int): size of the positional encoding; must equal embedding_dim
            since the two are summed
        dropout_prob (float): dropout parameter
    """

    def __init__(self, num_embeddings, embedding_dim, dim, dropout_prob=0., padding_idx=0, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2) *
                              -(math.log(10000.0) / dim)).float())
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.weight = self.embedding.weight
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.dim = dim

    def forward(self, x, step=None):
        x = self.embedding(x)
        x = x * math.sqrt(self.dim)  # scale embeddings as in the paper
        if step is None:
            x = x + self.pe[:, :x.size(1)]
        else:
            x = x + self.pe[:, step]
        x = self.dropout(x)
        return x

--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
from evaluator import Evaluator
from predictors import Predictor
from models import build_model
from datasets import TranslationDataset
from datasets import IndexedInputTargetTranslationDataset
from dictionaries import IndexDictionary

from argparse import ArgumentParser
import json
from datetime import datetime

parser = ArgumentParser(description='Evaluate translation')
parser.add_argument('--save_result', type=str, default=None)
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--checkpoint', type=str, required=True)
parser.add_argument('--phase', type=str, default='val', choices=['train', 'val'])

args = parser.parse_args()
with open(args.config) as f:
    config = json.load(f)

print('Constructing dictionaries...')
source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])

print('Building model...')
model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)

predictor = Predictor(
    preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
    postprocess=lambda x: ' '.join([token for token in target_dictionary.tokenify_indexes(x) if token != '<EndSent>']),
    model=model,
    checkpoint_filepath=args.checkpoint
)

timestamp = datetime.now()
if args.save_result is None:
    eval_filepath = 'logs/eval-{config}-time={timestamp}.csv'.format(
        config=args.config.replace('/', '-'),
        timestamp=timestamp.strftime("%Y_%m_%d_%H_%M_%S"))
else:
    eval_filepath = args.save_result

evaluator = Evaluator(
    predictor=predictor,
    save_filepath=eval_filepath
)

print('Evaluating...')
test_dataset = TranslationDataset(config['data_dir'], args.phase, limit=1000)
bleu_score = evaluator.evaluate_dataset(test_dataset)
print('Evaluation time :', datetime.now() - timestamp)

print("BLEU score :", bleu_score)

--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
from tqdm import tqdm


class Evaluator:

    def __init__(self, predictor, save_filepath):

        self.predictor = predictor
        self.save_filepath = save_filepath

    def evaluate_dataset(self, test_dataset):
        tokenize = lambda x: x.split()

        predictions = []
        for source, target in tqdm(test_dataset):
            prediction = self.predictor.predict_one(source, num_candidates=1)[0]
            predictions.append(prediction)

        hypotheses = [tokenize(prediction) for prediction in predictions]
        list_of_references = [[tokenize(target)] for source, target in test_dataset]
        smoothing_function = SmoothingFunction()

        with open(self.save_filepath, 'w') as file:
            for (source, target), prediction, hypothesis, references in zip(test_dataset, predictions,
                                                                            hypotheses, list_of_references):
                sentence_bleu_score = sentence_bleu(references, hypothesis,
                                                    smoothing_function=smoothing_function.method3)
                line = "{bleu_score}\t{source}\t{target}\t|\t{prediction}".format(
                    bleu_score=sentence_bleu_score,
                    source=source,
                    target=target,
                    prediction=prediction
                )
                file.write(line + '\n')

        bleu_score = corpus_bleu(list_of_references, hypotheses, smoothing_function=smoothing_function.method3)

        return bleu_score

--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class TokenCrossEntropyLoss(nn.Module):

    def __init__(self, pad_index=0):
        super(TokenCrossEntropyLoss, self).__init__()

        self.pad_index = pad_index
        self.base_loss_function = nn.CrossEntropyLoss(reduction='sum', ignore_index=pad_index)

    def forward(self, outputs, targets):
        batch_size, seq_len, vocabulary_size = outputs.size()

        outputs_flat = outputs.view(batch_size * seq_len, vocabulary_size)
        targets_flat = targets.view(batch_size * seq_len)

        batch_loss = self.base_loss_function(outputs_flat, targets_flat)

        count = (targets != self.pad_index).sum().item()

        return batch_loss, count


class LabelSmoothingLoss(nn.Module):
    """
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.
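
    Concretely, each target distribution puts probability (1 - label_smoothing)
    on the true token and spreads label_smoothing uniformly over the remaining
    vocabulary entries (the padding entry is zeroed out); the smoothed_targets
    buffer built below encodes the uniform part of that distribution.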
    """
    def __init__(self, label_smoothing, vocabulary_size, pad_index=0):
        assert 0.0 < label_smoothing <= 1.0

        super(LabelSmoothingLoss, self).__init__()

        self.pad_index = pad_index
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.criterion = nn.KLDivLoss(reduction='sum')

        smoothing_value = label_smoothing / (vocabulary_size - 2)  # exclude pad and true label
        smoothed_targets = torch.full((vocabulary_size,), smoothing_value)
        smoothed_targets[self.pad_index] = 0
        self.register_buffer('smoothed_targets', smoothed_targets.unsqueeze(0))  # (1, vocabulary_size)

        self.confidence = 1.0 - label_smoothing

    def forward(self, outputs, targets):
        """
        outputs (FloatTensor): (batch_size, seq_len, vocabulary_size)
        targets (LongTensor): (batch_size, seq_len)
        """
        batch_size, seq_len, vocabulary_size = outputs.size()

        outputs_log_softmax = self.log_softmax(outputs)
        outputs_flat = outputs_log_softmax.view(batch_size * seq_len, vocabulary_size)
        targets_flat = targets.view(batch_size * seq_len)

        smoothed_targets = self.smoothed_targets.repeat(targets_flat.size(0), 1)
        # smoothed_targets: (batch_size * seq_len, vocabulary_size)

        smoothed_targets.scatter_(1, targets_flat.unsqueeze(1), self.confidence)
        # smoothed_targets: (batch_size * seq_len, vocabulary_size)

        smoothed_targets.masked_fill_((targets_flat == self.pad_index).unsqueeze(1), 0)
        # masked_targets: (batch_size * seq_len, vocabulary_size)

        loss = self.criterion(outputs_flat, smoothed_targets)
        count = (targets != self.pad_index).sum().item()

        return loss, count

--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
from torch import nn


class AccuracyMetric(nn.Module):

    def __init__(self, pad_index=0):
        super(AccuracyMetric, self).__init__()

        self.pad_index = pad_index

    def forward(self, outputs, targets):

        batch_size, seq_len, vocabulary_size = outputs.size()

        outputs = outputs.view(batch_size * seq_len, vocabulary_size)
        targets = targets.view(batch_size * seq_len)

        predicts = outputs.argmax(dim=1)
        corrects = predicts == targets

        corrects.masked_fill_((targets == self.pad_index), 0)

        correct_count = corrects.sum().item()
        count = (targets != self.pad_index).sum().item()

        return correct_count, count

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from embeddings import PositionalEncoding
from utils.pad import pad_masking, subsequent_masking

import torch
from torch import nn
import numpy as np
from collections import defaultdict

PAD_TOKEN_ID = 0


def build_model(config, source_vocabulary_size, target_vocabulary_size):
    if config['positional_encoding']:
        source_embedding = PositionalEncoding(
            num_embeddings=source_vocabulary_size,
            embedding_dim=config['d_model'],
            dim=config['d_model'])
        target_embedding = PositionalEncoding(
            num_embeddings=target_vocabulary_size,
            embedding_dim=config['d_model'],
            dim=config['d_model'])
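        # PositionalEncoding takes both embedding_dim and dim; they are set to
        # the same d_model because the token embedding and the sinusoidal
        # encoding are summed elementwise and must share a dimensionality.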
    else:
        source_embedding = nn.Embedding(
            num_embeddings=source_vocabulary_size,
            embedding_dim=config['d_model'])
        target_embedding = nn.Embedding(
            num_embeddings=target_vocabulary_size,
            embedding_dim=config['d_model'])

    encoder = TransformerEncoder(
        layers_count=config['layers_count'],
        d_model=config['d_model'],
        heads_count=config['heads_count'],
        d_ff=config['d_ff'],
        dropout_prob=config['dropout_prob'],
        embedding=source_embedding)

    decoder = TransformerDecoder(
        layers_count=config['layers_count'],
        d_model=config['d_model'],
        heads_count=config['heads_count'],
        d_ff=config['d_ff'],
        dropout_prob=config['dropout_prob'],
        embedding=target_embedding)

    model = Transformer(encoder, decoder)

    return model


class Transformer(nn.Module):

    def __init__(self, encoder, decoder):
        super(Transformer, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, sources, inputs):
        # sources: (batch_size, sources_len)
        # inputs: (batch_size, targets_len - 1)
        batch_size, sources_len = sources.size()
        batch_size, inputs_len = inputs.size()

        sources_mask = pad_masking(sources, sources_len)
        memory_mask = pad_masking(sources, inputs_len)
        inputs_mask = subsequent_masking(inputs) | pad_masking(inputs, inputs_len)

        memory = self.encoder(sources, sources_mask)  # (batch_size, seq_len, d_model)
        outputs, state = self.decoder(inputs, memory, memory_mask, inputs_mask)  # (batch_size, seq_len, d_model)
        return outputs


class TransformerEncoder(nn.Module):

    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
        super(TransformerEncoder, self).__init__()

        self.d_model = d_model
        self.embedding = embedding
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderLayer(d_model, heads_count, d_ff, dropout_prob) for _ in range(layers_count)]
        )

    def forward(self, sources, mask):
        """
        Args:
            sources: indexed source sequence, (batch_size, seq_len)
        """
        sources = self.embedding(sources)

        for encoder_layer in self.encoder_layers:
            sources = encoder_layer(sources, mask)

        return sources


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super(TransformerEncoderLayer, self).__init__()

        self.self_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob), d_model)
        self.pointwise_feedforward_layer = Sublayer(PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, sources, sources_mask):
        # sources: (batch_size, seq_len, d_model)

        sources = self.self_attention_layer(sources, sources, sources, sources_mask)
        sources = self.dropout(sources)
        sources = self.pointwise_feedforward_layer(sources)

        return sources


class TransformerDecoder(nn.Module):

    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
        super(TransformerDecoder, self).__init__()

        self.d_model = d_model
        self.embedding = embedding
        self.decoder_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads_count, d_ff, dropout_prob) for _ in range(layers_count)]
        )
        self.generator = nn.Linear(embedding.embedding_dim, embedding.num_embeddings)
        self.generator.weight = self.embedding.weight  # weight tying with the target embedding

    def forward(self, inputs, memory, memory_mask, inputs_mask=None, state=None):
        # inputs: (batch_size, seq_len - 1)
        # memory: (batch_size, seq_len, d_model)

        inputs = self.embedding(inputs)

        for layer_index, decoder_layer in enumerate(self.decoder_layers):
            if state is None:
                inputs = decoder_layer(inputs, memory, memory_mask, inputs_mask)
            else:  # Use cache
                layer_cache = state.layer_caches[layer_index]
                inputs = decoder_layer(inputs, memory, memory_mask, inputs_mask, layer_cache)

                state.update_state(
                    layer_index=layer_index,
                    layer_mode='self-attention',
                    key_projected=decoder_layer.self_attention_layer.sublayer.key_projected,
                    value_projected=decoder_layer.self_attention_layer.sublayer.value_projected,
                )
                state.update_state(
                    layer_index=layer_index,
                    layer_mode='memory-attention',
                    key_projected=decoder_layer.memory_attention_layer.sublayer.key_projected,
                    value_projected=decoder_layer.memory_attention_layer.sublayer.value_projected,
                )

        generated = self.generator(inputs)  # (batch_size, seq_len, vocab_size)
        return generated, state

    def init_decoder_state(self, **kwargs):
        return DecoderState()


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob, mode='self-attention'), d_model)
        self.memory_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob, mode='memory-attention'), d_model)
        self.pointwise_feedforward_layer = Sublayer(PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)

    def forward(self, inputs, memory, memory_mask, inputs_mask, layer_cache=None):
        inputs = self.self_attention_layer(inputs, inputs, inputs, inputs_mask, layer_cache)
        inputs = self.memory_attention_layer(inputs, memory, memory, memory_mask, layer_cache)
        inputs = self.pointwise_feedforward_layer(inputs)
        return inputs


class Sublayer(nn.Module):

    def __init__(self, sublayer, d_model):
        super(Sublayer, self).__init__()

        self.sublayer = sublayer
        self.layer_normalization = LayerNormalization(d_model)

    def forward(self, *args):
        x = args[0]
        x = self.sublayer(*args) + x  # residual connection
        return self.layer_normalization(x)


class LayerNormalization(nn.Module):

    def __init__(self, features_count, epsilon=1e-6):
        super(LayerNormalization, self).__init__()

        self.gain = nn.Parameter(torch.ones(features_count))
        self.bias = nn.Parameter(torch.zeros(features_count))
        self.epsilon = epsilon

    def forward(self, x):

        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)

        return self.gain * (x - mean) / (std + self.epsilon) + self.bias


class MultiHeadAttention(nn.Module):

    def __init__(self, heads_count, d_model, dropout_prob, mode='self-attention'):
        super(MultiHeadAttention, self).__init__()

        assert d_model % heads_count == 0
        assert mode in ('self-attention', 'memory-attention')

        self.d_head = d_model // heads_count
        self.heads_count = heads_count
        self.mode = mode
        self.query_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.key_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.value_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.final_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.dropout = nn.Dropout(dropout_prob)
        self.softmax = nn.Softmax(dim=3)

        self.attention = None
        # For cache
        self.key_projected = None
        self.value_projected = None

    def forward(self, query, key, value, mask=None, layer_cache=None):
        """
        Args:
            query: (batch_size, query_len, model_dim)
            key: (batch_size, key_len, model_dim)
            value: (batch_size, value_len, model_dim)
            mask: (batch_size, query_len, key_len)
            layer_cache: dict of cached key/value projections for decoding
        """
        batch_size, query_len, d_model = query.size()

        d_head = d_model // self.heads_count

        query_projected = self.query_projection(query)
        if layer_cache is None or layer_cache[self.mode] is None:  # Don't use cache
            key_projected = self.key_projection(key)
            value_projected = self.value_projection(value)
        else:  # Use cache
            if self.mode == 'self-attention':
                key_projected = self.key_projection(key)
                value_projected = self.value_projection(value)

                key_projected = torch.cat([key_projected, layer_cache[self.mode]['key_projected']], dim=1)
                value_projected = torch.cat([value_projected, layer_cache[self.mode]['value_projected']], dim=1)
            elif self.mode == 'memory-attention':
                key_projected = layer_cache[self.mode]['key_projected']
                value_projected = layer_cache[self.mode]['value_projected']

        # For cache
        self.key_projected = key_projected
        self.value_projected = value_projected

        batch_size, key_len, d_model = key_projected.size()
        batch_size, value_len, d_model = value_projected.size()

        query_heads = query_projected.view(batch_size, query_len, self.heads_count, d_head).transpose(1, 2)  # (batch_size, heads_count, query_len, d_head)
        key_heads = key_projected.view(batch_size, key_len, self.heads_count, d_head).transpose(1, 2)  # (batch_size, heads_count, key_len, d_head)
        value_heads = value_projected.view(batch_size, value_len, self.heads_count, d_head).transpose(1, 2)  # (batch_size, heads_count, value_len, d_head)

        attention_weights = self.scaled_dot_product(query_heads, key_heads)  # (batch_size, heads_count, query_len, key_len)

        if mask is not None:
            mask_expanded = mask.unsqueeze(1).expand_as(attention_weights)
            attention_weights = attention_weights.masked_fill(mask_expanded, -1e18)

        self.attention = self.softmax(attention_weights)  # Save attention to the object
        attention_dropped = self.dropout(self.attention)
        context_heads = torch.matmul(attention_dropped, value_heads)  # (batch_size, heads_count, query_len, d_head)
        context_sequence = context_heads.transpose(1, 2).contiguous()  # (batch_size, query_len, heads_count, d_head)
        context = context_sequence.view(batch_size, query_len, d_model)  # (batch_size, query_len, d_model)
        final_output = self.final_projection(context)

        return final_output

    def scaled_dot_product(self, query_heads, key_heads):
        """
        Args:
            query_heads: (batch_size, heads_count, query_len, d_head)
            key_heads: (batch_size, heads_count, key_len, d_head)
        """
        key_heads_transposed = key_heads.transpose(2, 3)
        dot_product = torch.matmul(query_heads, key_heads_transposed)  # (batch_size, heads_count, query_len, key_len)
        attention_weights = dot_product / np.sqrt(self.d_head)
        return attention_weights


class PointwiseFeedForwardNetwork(nn.Module):

    def __init__(self, d_ff, d_model, dropout_prob):
        super(PointwiseFeedForwardNetwork, self).__init__()

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout_prob),
        )

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        """
        return self.feed_forward(x)


class DecoderState:

    def __init__(self):
        self.previous_inputs = torch.tensor([])
        self.layer_caches = defaultdict(lambda: {'self-attention': None, 'memory-attention': None})

    def update_state(self, layer_index, layer_mode, key_projected, value_projected):
        self.layer_caches[layer_index][layer_mode] = {
            'key_projected': key_projected,
            'value_projected': value_projected
        }

    # def repeat_beam_size_times(self, beam_size):
    #     # Only the memory would need repeating, but we decided not to keep
    #     # the memory in the state.
    #     self.src = self.src.data.repeat(beam_size, 1)

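    # beam_update is called once per decoding step with the beam's
    # backpointers: every cached key/value projection is reordered along
    # dim 0 (the beam dimension) to stay aligned with surviving hypotheses.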
    def beam_update(self, positions):
        for layer_index in self.layer_caches:
            for mode in ('self-attention', 'memory-attention'):
                if self.layer_caches[layer_index][mode] is not None:
                    for projection in self.layer_caches[layer_index][mode]:
                        cache = self.layer_caches[layer_index][mode][projection]
                        if cache is not None:
                            cache.data.copy_(cache.data.index_select(0, positions))

--------------------------------------------------------------------------------
/optimizers.py:
--------------------------------------------------------------------------------
from torch.optim import Adam


class NoamOptimizer(Adam):

    def __init__(self, params, d_model, factor=2, warmup_steps=4000, betas=(0.9, 0.98), eps=1e-9):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.lr = 0
        self.step_num = 0
        self.factor = factor

        super(NoamOptimizer, self).__init__(params, betas=betas, eps=eps)

    def step(self, closure=None):
        self.step_num += 1
        self.lr = self.lrate()
        for group in self.param_groups:
            group['lr'] = self.lr
        super(NoamOptimizer, self).step(closure)

    def lrate(self):
        return self.factor * self.d_model ** (-0.5) * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))

--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
from predictors import Predictor
from models import build_model
from datasets import IndexedInputTargetTranslationDataset
from dictionaries import IndexDictionary

from argparse import ArgumentParser
import json

parser = ArgumentParser(description='Predict translation')
parser.add_argument('--source', type=str)
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--checkpoint', type=str)
parser.add_argument('--num_candidates', type=int, default=3)

args = parser.parse_args()
with open(args.config) as f:
    config = json.load(f)

print('Constructing dictionaries...')
source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])

print('Building model...')
model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)

predictor = Predictor(
    preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
    postprocess=lambda x: ' '.join([token for token in target_dictionary.tokenify_indexes(x) if token != '<EndSent>']),
    model=model,
    checkpoint_filepath=args.checkpoint
)

for index, candidate in enumerate(predictor.predict_one(args.source, num_candidates=args.num_candidates)):
    print(f'Candidate {index} : {candidate}')

--------------------------------------------------------------------------------
/predictors.py:
--------------------------------------------------------------------------------
from beam import Beam
from utils.pad import pad_masking

import torch


class Predictor:

    def __init__(self, preprocess, postprocess, model, checkpoint_filepath, max_length=30, beam_size=8):
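        """Wrap a trained Transformer for beam-search translation of raw text.

        preprocess maps a source sentence to token indexes, postprocess maps
        predicted indexes back to a string, and the checkpoint weights are
        loaded onto the CPU.
        """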
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.model = model
        self.max_length = max_length
        self.beam_size = beam_size

        self.model.eval()
        checkpoint = torch.load(checkpoint_filepath, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def predict_one(self, source, num_candidates=5):
        source_preprocessed = self.preprocess(source)
        source_tensor = torch.tensor(source_preprocessed).unsqueeze(0)  # add a batch dimension of 1

        sources_mask = pad_masking(source_tensor, source_tensor.size(1))
        memory_mask = pad_masking(source_tensor, 1)
        memory = self.model.encoder(source_tensor, sources_mask)

        decoder_state = self.model.decoder.init_decoder_state()

        # Repeat beam_size times
        memory_beam = memory.detach().repeat(self.beam_size, 1, 1)  # (beam_size, seq_len, hidden_size)

        beam = Beam(beam_size=self.beam_size, min_length=0, n_top=num_candidates, ranker=None)

        for _ in range(self.max_length):

            new_inputs = beam.get_current_state().unsqueeze(1)  # (beam_size, seq_len=1)
            decoder_outputs, decoder_state = self.model.decoder(new_inputs, memory_beam,
                                                                memory_mask,
                                                                state=decoder_state)
            # decoder_outputs: (beam_size, target_seq_len=1, vocabulary_size)

            attention = self.model.decoder.decoder_layers[-1].memory_attention_layer.sublayer.attention
            beam.advance(decoder_outputs.squeeze(1), attention)

            beam_current_origin = beam.get_current_origin()  # (beam_size, )
            decoder_state.beam_update(beam_current_origin)

            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=num_candidates)
        hypothesises, attentions = [], []
        for i, (times, k) in enumerate(ks[:num_candidates]):
            hypothesis, attention = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)
            attentions.append(attention)

        self.attentions = attentions
        self.hypothesises = [[token.item() for token in h] for h in hypothesises]
        hs = [self.postprocess(h) for h in self.hypothesises]
        return list(reversed(hs))

--------------------------------------------------------------------------------
/prepare_datasets.py:
--------------------------------------------------------------------------------
from datasets import TranslationDataset, TranslationDatasetOnTheFly
from datasets import TokenizedTranslationDataset, TokenizedTranslationDatasetOnTheFly
from datasets import InputTargetTranslationDataset, InputTargetTranslationDatasetOnTheFly
from datasets import IndexedInputTargetTranslationDataset, IndexedInputTargetTranslationDatasetOnTheFly
from dictionaries import IndexDictionary
from utils.pipe import shared_tokens_generator, source_tokens_generator, target_tokens_generator

from argparse import ArgumentParser

parser = ArgumentParser('Prepare datasets')
parser.add_argument('--train_source', type=str, default='data/example/raw/src-train.txt')
parser.add_argument('--train_target', type=str, default='data/example/raw/tgt-train.txt')
parser.add_argument('--val_source', type=str, default='data/example/raw/src-val.txt')
parser.add_argument('--val_target', type=str, default='data/example/raw/tgt-val.txt')
parser.add_argument('--save_data_dir', type=str, default='data/example/processed')
parser.add_argument('--share_dictionary', action='store_true')  # type=bool would treat any non-empty string as True

args = parser.parse_args()

TranslationDataset.prepare(args.train_source, args.train_target, args.val_source, args.val_target, args.save_data_dir)
translation_dataset = TranslationDataset(args.save_data_dir, 'train')
translation_dataset_on_the_fly = TranslationDatasetOnTheFly('train')
assert translation_dataset[0] == translation_dataset_on_the_fly[0]

tokenized_dataset = TokenizedTranslationDataset(args.save_data_dir, 'train')

if args.share_dictionary:
    source_generator = shared_tokens_generator(tokenized_dataset)
    source_dictionary = IndexDictionary(source_generator, mode='source')
    target_generator = shared_tokens_generator(tokenized_dataset)
    target_dictionary = IndexDictionary(target_generator, mode='target')
else:
    source_generator = source_tokens_generator(tokenized_dataset)
    source_dictionary = IndexDictionary(source_generator, mode='source')
    target_generator = target_tokens_generator(tokenized_dataset)
    target_dictionary = IndexDictionary(target_generator, mode='target')

source_dictionary.save(args.save_data_dir)
target_dictionary.save(args.save_data_dir)

source_dictionary = IndexDictionary.load(args.save_data_dir, mode='source')
target_dictionary = IndexDictionary.load(args.save_data_dir, mode='target')

IndexedInputTargetTranslationDataset.prepare(args.save_data_dir, source_dictionary, target_dictionary)
indexed_translation_dataset = IndexedInputTargetTranslationDataset(args.save_data_dir, 'train')
indexed_translation_dataset_on_the_fly = IndexedInputTargetTranslationDatasetOnTheFly('train', source_dictionary, target_dictionary)
assert indexed_translation_dataset[0] == indexed_translation_dataset_on_the_fly[0]

print('Done datasets preparation.')

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from models import build_model
from datasets import IndexedInputTargetTranslationDataset
from dictionaries import IndexDictionary
from losses import TokenCrossEntropyLoss, LabelSmoothingLoss
from metrics import AccuracyMetric
from optimizers import NoamOptimizer
from trainer import EpochSeq2SeqTrainer
from utils.log import get_logger
from utils.pipe import input_target_collate_fn

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np

from argparse import ArgumentParser
from datetime import datetime
import json
import random

parser = ArgumentParser(description='Train Transformer')
parser.add_argument('--config', type=str, default=None)

parser.add_argument('--data_dir', type=str, default='data/example/processed')
parser.add_argument('--save_config', type=str, default=None)
parser.add_argument('--save_checkpoint', type=str, default=None)
parser.add_argument('--save_log', type=str, default=None)

parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
/train.py:
--------------------------------------------------------------------------------
1 | from models import build_model
2 | from datasets import IndexedInputTargetTranslationDataset
3 | from dictionaries import IndexDictionary
4 | from losses import TokenCrossEntropyLoss, LabelSmoothingLoss
5 | from metrics import AccuracyMetric
6 | from optimizers import NoamOptimizer
7 | from trainer import EpochSeq2SeqTrainer
8 | from utils.log import get_logger
9 | from utils.pipe import input_target_collate_fn
10 | 
11 | import torch
12 | from torch.optim import Adam
13 | from torch.utils.data import DataLoader
14 | import numpy as np
15 | 
16 | from argparse import ArgumentParser
17 | from datetime import datetime
18 | import json
19 | import random
20 | 
21 | parser = ArgumentParser(description='Train Transformer')
22 | parser.add_argument('--config', type=str, default=None)
23 | 
24 | parser.add_argument('--data_dir', type=str, default='data/example/processed')
25 | parser.add_argument('--save_config', type=str, default=None)
26 | parser.add_argument('--save_checkpoint', type=str, default=None)
27 | parser.add_argument('--save_log', type=str, default=None)
28 | 
29 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
30 | 
31 | parser.add_argument('--dataset_limit', type=int, default=None)
32 | parser.add_argument('--print_every', type=int, default=1)
33 | parser.add_argument('--save_every', type=int, default=1)
34 | 
35 | parser.add_argument('--vocabulary_size', type=int, default=None)
36 | parser.add_argument('--positional_encoding', action='store_true')
37 | 
38 | parser.add_argument('--d_model', type=int, default=128)
39 | parser.add_argument('--layers_count', type=int, default=1)
40 | parser.add_argument('--heads_count', type=int, default=2)
41 | parser.add_argument('--d_ff', type=int, default=128)
42 | parser.add_argument('--dropout_prob', type=float, default=0.1)
43 | 
44 | parser.add_argument('--label_smoothing', type=float, default=0.1)
45 | parser.add_argument('--optimizer', type=str, default="Adam", choices=["Noam", "Adam"])
46 | parser.add_argument('--lr', type=float, default=0.001)
47 | parser.add_argument('--clip_grads', action='store_true')
48 | 
49 | parser.add_argument('--batch_size', type=int, default=64)
50 | parser.add_argument('--epochs', type=int, default=100)
51 | 
52 | 
53 | def run_trainer(config):
54 |     random.seed(0)
55 |     np.random.seed(0)
56 |     torch.manual_seed(0)
57 | 
58 |     run_name_format = (
59 |         "d_model={d_model}-"
60 |         "layers_count={layers_count}-"
61 |         "heads_count={heads_count}-"
62 |         "pe={positional_encoding}-"
63 |         "optimizer={optimizer}-"
64 |         "{timestamp}"
65 |     )
66 | 
67 |     run_name = run_name_format.format(**config, timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
68 | 
69 |     logger = get_logger(run_name, save_log=config['save_log'])
70 |     logger.info(f'Run name : {run_name}')
71 |     logger.info(config)
72 | 
73 |     logger.info('Constructing dictionaries...')
74 |     source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
75 |     target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])
76 |     logger.info(f'Source dictionary vocabulary : {source_dictionary.vocabulary_size} tokens')
77 |     logger.info(f'Target dictionary vocabulary : {target_dictionary.vocabulary_size} tokens')
78 | 
79 |     logger.info('Building model...')
80 |     model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)
81 | 
82 |     logger.info(model)
83 |     logger.info('Encoder : {parameters_count} parameters'.format(parameters_count=sum(p.nelement() for p in model.encoder.parameters())))
84 |     logger.info('Decoder : {parameters_count} parameters'.format(parameters_count=sum(p.nelement() for p in model.decoder.parameters())))
85 |     logger.info('Total : {parameters_count} parameters'.format(parameters_count=sum(p.nelement() for p in model.parameters())))
86 | 
87 |     logger.info('Loading datasets...')
88 |     train_dataset = IndexedInputTargetTranslationDataset(
89 |         data_dir=config['data_dir'],
90 |         phase='train',
91 |         vocabulary_size=config['vocabulary_size'],
92 |         limit=config['dataset_limit'])
93 | 
94 |     val_dataset = IndexedInputTargetTranslationDataset(
95 |         data_dir=config['data_dir'],
96 |         phase='val',
97 |         vocabulary_size=config['vocabulary_size'],
98 |         limit=config['dataset_limit'])
99 | 
100 |     train_dataloader = DataLoader(
101 |         train_dataset,
102 |         batch_size=config['batch_size'],
103 |         shuffle=True,
104 |         collate_fn=input_target_collate_fn)
105 | 
106 |     val_dataloader = DataLoader(
107 |         val_dataset,
108 |         batch_size=config['batch_size'],
109 |         collate_fn=input_target_collate_fn)
110 | 
111 |     if config['label_smoothing'] > 0.0:
112 |         loss_function = LabelSmoothingLoss(label_smoothing=config['label_smoothing'],
113 |                                            vocabulary_size=target_dictionary.vocabulary_size)
114 |     else:
115 |         loss_function = TokenCrossEntropyLoss()
116 | 
117 |     accuracy_function = AccuracyMetric()
118 | 
119 |     if config['optimizer'] == 'Noam':
120 |         optimizer = NoamOptimizer(model.parameters(), d_model=config['d_model'])
121 |     elif config['optimizer'] == 'Adam':
122 |         optimizer = Adam(model.parameters(), lr=config['lr'])
123 |     else:
124 |         raise NotImplementedError()
125 | 
126 |     logger.info('Start training...')
127 |     trainer = EpochSeq2SeqTrainer(
128 |         model=model,
129 |         train_dataloader=train_dataloader,
130 |         val_dataloader=val_dataloader,
131 |         loss_function=loss_function,
132 |         metric_function=accuracy_function,
133 |         optimizer=optimizer,
134 |         logger=logger,
135 |         run_name=run_name,
136 |         save_config=config['save_config'],
137 |         save_checkpoint=config['save_checkpoint'],
138 |         config=config
139 |     )
140 | 
141 |     trainer.run(config['epochs'])
142 | 
143 |     return trainer
144 | 
145 | 
146 | if __name__ == '__main__':
147 | 
148 |     args = parser.parse_args()
149 | 
150 |     if args.config is not None:
151 |         with open(args.config) as f:
152 |             config = json.load(f)
153 | 
154 |         default_config = vars(args)
155 |         for key, default_value in default_config.items():
156 |             if key not in config:
157 |                 config[key] = default_value
158 |     else:
159 |         config = vars(args)  # convert argparse Namespace to a plain dictionary
160 | 
161 |     run_trainer(config)
162 | 
--------------------------------------------------------------------------------
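Note: when `--optimizer=Noam` is chosen above, the learning-rate schedule comes from `optimizers.py`, which is not part of this listing. For reference, a minimal sketch of the Noam schedule from "Attention is All You Need" is given below; `warmup_steps=4000` is the paper's value, and the function name is illustrative rather than the repo's exact API.

```
# Sketch of the Noam learning-rate schedule (Vaswani et al., 2017):
# lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
def noam_lr(step, d_model, warmup_steps=4000):
    step = max(step, 1)  # guard against division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

The rate grows linearly for the first `warmup_steps` updates, then decays with the inverse square root of the step number.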
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 | 
5 | from os.path import dirname, abspath, join, exists
6 | from os import makedirs
7 | from datetime import datetime
8 | import json
9 | 
10 | PAD_INDEX = 0
11 | 
12 | BASE_DIR = dirname(abspath(__file__))
13 | 
14 | 
15 | class EpochSeq2SeqTrainer:
16 | 
17 |     def __init__(self, model,
18 |                  train_dataloader, val_dataloader,
19 |                  loss_function, metric_function, optimizer,
20 |                  logger, run_name,
21 |                  save_config, save_checkpoint,
22 |                  config):
23 | 
24 |         self.config = config
25 |         self.device = torch.device(self.config['device'])
26 | 
27 |         self.model = model.to(self.device)
28 |         self.train_dataloader = train_dataloader
29 |         self.val_dataloader = val_dataloader
30 | 
31 |         self.loss_function = loss_function.to(self.device)
32 |         self.metric_function = metric_function
33 |         self.optimizer = optimizer
34 |         self.clip_grads = self.config['clip_grads']
35 | 
36 |         self.logger = logger
37 |         self.checkpoint_dir = join(BASE_DIR, 'checkpoints', run_name)
38 | 
39 |         if not exists(self.checkpoint_dir):
40 |             makedirs(self.checkpoint_dir)
41 | 
42 |         if save_config is None:
43 |             config_filepath = join(self.checkpoint_dir, 'config.json')
44 |         else:
45 |             config_filepath = save_config
46 |         with open(config_filepath, 'w') as config_file:
47 |             json.dump(self.config, config_file)
48 | 
49 |         self.print_every = self.config['print_every']
50 |         self.save_every = self.config['save_every']
51 | 
52 |         self.epoch = 0
53 |         self.history = []
54 | 
55 |         self.start_time = datetime.now()
56 | 
57 |         self.best_val_metric = None
58 |         self.best_checkpoint_filepath = None
59 | 
60 |         self.save_checkpoint = save_checkpoint
61 |         self.save_format = 'epoch={epoch:0>3}-val_loss={val_loss:<.3}-val_metrics={val_metrics}.pth'
62 | 
63 |         self.log_format = (
64 |             "Epoch: {epoch:>3} "
65 |             "Progress: {progress:<.1%} "
66 |             "Elapsed: {elapsed} "
67 |             "Examples/second: {per_second:<.1} "
68 |             "Train Loss: {train_loss:<.6} "
69 |             "Val Loss: {val_loss:<.6} "
70 |             "Train Metrics: {train_metrics} "
71 |             "Val Metrics: {val_metrics} "
72 |             "Learning rate: {current_lr:<.4} "
73 |         )
74 | 
{elapsed} " 67 | "Examples/second: {per_second:<.1} " 68 | "Train Loss: {train_loss:<.6} " 69 | "Val Loss: {val_loss:<.6} " 70 | "Train Metrics: {train_metrics} " 71 | "Val Metrics: {val_metrics} " 72 | "Learning rate: {current_lr:<.4} " 73 | ) 74 | 75 | def run_epoch(self, dataloader, mode='train'): 76 | batch_losses = [] 77 | batch_counts = [] 78 | batch_metrics = [] 79 | for sources, inputs, targets in tqdm(dataloader): 80 | sources, inputs, targets = sources.to(self.device), inputs.to(self.device), targets.to(self.device) 81 | outputs = self.model(sources, inputs) 82 | 83 | batch_loss, batch_count = self.loss_function(outputs, targets) 84 | 85 | if mode == 'train': 86 | self.optimizer.zero_grad() 87 | batch_loss.backward() 88 | if self.clip_grads: 89 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) 90 | self.optimizer.step() 91 | 92 | batch_losses.append(batch_loss.item()) 93 | batch_counts.append(batch_count) 94 | 95 | batch_metric, batch_metric_count = self.metric_function(outputs, targets) 96 | batch_metrics.append(batch_metric) 97 | 98 | assert batch_count == batch_metric_count 99 | 100 | if self.epoch == 0: # for testing 101 | return float('inf'), [float('inf')] 102 | 103 | epoch_loss = sum(batch_losses) / sum(batch_counts) 104 | epoch_accuracy = sum(batch_metrics) / sum(batch_counts) 105 | epoch_perplexity = float(np.exp(epoch_loss)) 106 | epoch_metrics = [epoch_perplexity, epoch_accuracy] 107 | 108 | return epoch_loss, epoch_metrics 109 | 110 | def run(self, epochs=10): 111 | 112 | for epoch in range(self.epoch, epochs + 1): 113 | self.epoch = epoch 114 | 115 | self.model.train() 116 | 117 | epoch_start_time = datetime.now() 118 | train_epoch_loss, train_epoch_metrics = self.run_epoch(self.train_dataloader, mode='train') 119 | epoch_end_time = datetime.now() 120 | 121 | self.model.eval() 122 | 123 | val_epoch_loss, val_epoch_metrics = self.run_epoch(self.val_dataloader, mode='val') 124 | 125 | if epoch % self.print_every == 0 and self.logger: 126 | per_second = len(self.train_dataloader.dataset) / ((epoch_end_time - epoch_start_time).seconds + 1) 127 | current_lr = self.optimizer.param_groups[0]['lr'] 128 | log_message = self.log_format.format(epoch=epoch, 129 | progress=epoch / epochs, 130 | per_second=per_second, 131 | train_loss=train_epoch_loss, 132 | val_loss=val_epoch_loss, 133 | train_metrics=[round(metric, 4) for metric in train_epoch_metrics], 134 | val_metrics=[round(metric, 4) for metric in val_epoch_metrics], 135 | current_lr=current_lr, 136 | elapsed=self._elapsed_time() 137 | ) 138 | 139 | self.logger.info(log_message) 140 | 141 | if epoch % self.save_every == 0: 142 | self._save_model(epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics) 143 | 144 | def _save_model(self, epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics): 145 | 146 | checkpoint_filename = self.save_format.format( 147 | epoch=epoch, 148 | val_loss=val_epoch_loss, 149 | val_metrics='-'.join(['{:<.3}'.format(v) for v in val_epoch_metrics]) 150 | ) 151 | 152 | if self.save_checkpoint is None: 153 | checkpoint_filepath = join(self.checkpoint_dir, checkpoint_filename) 154 | else: 155 | checkpoint_filepath = self.save_checkpoint 156 | 157 | save_state = { 158 | 'epoch': epoch, 159 | 'train_loss': train_epoch_loss, 160 | 'train_metrics': train_epoch_metrics, 161 | 'val_loss': val_epoch_loss, 162 | 'val_metrics': val_epoch_metrics, 163 | 'checkpoint': checkpoint_filepath, 164 | } 165 | 166 | if self.epoch > 0: 167 | 
/utils/log.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, abspath, join, exists
2 | import os
3 | import sys
4 | import logging
5 | 
6 | BASE_DIR = dirname(dirname(abspath(__file__)))
7 | 
8 | 
9 | def get_logger(run_name, save_log=None):
10 |     log_dir = join(BASE_DIR, 'logs')
11 |     if not exists(log_dir):
12 |         os.makedirs(log_dir)
13 | 
14 |     log_filename = f'{run_name}.log'
15 |     if save_log is None:
16 |         log_filepath = join(log_dir, log_filename)
17 |     else:
18 |         log_filepath = save_log
19 | 
20 |     logger = logging.getLogger(run_name)
21 | 
22 |     if not logger.handlers:  # attach handlers only once per run_name
23 |         file_handler = logging.FileHandler(log_filepath, 'w', 'utf-8')
24 |         stream_handler = logging.StreamHandler(sys.stdout)
25 | 
26 |         formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
27 | 
28 |         file_handler.setFormatter(formatter)
29 |         stream_handler.setFormatter(formatter)
30 | 
31 |         logger.addHandler(file_handler)
32 |         logger.addHandler(stream_handler)
33 |         logger.setLevel(logging.INFO)
34 | 
35 |     return logger
--------------------------------------------------------------------------------
/utils/pad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | PAD_TOKEN_INDEX = 0
5 | 
6 | 
7 | def pad_masking(x, target_len):
8 |     # x: (batch_size, seq_len)
9 |     batch_size, seq_len = x.size()
10 |     padded_positions = x == PAD_TOKEN_INDEX  # (batch_size, seq_len)
11 |     pad_mask = padded_positions.unsqueeze(1).expand(batch_size, target_len, seq_len)
12 |     return pad_mask
13 | 
14 | 
15 | def subsequent_masking(x):
16 |     # x: (batch_size, seq_len - 1)
17 |     batch_size, seq_len = x.size()
18 |     subsequent_mask = np.triu(np.ones(shape=(seq_len, seq_len)), k=1).astype('uint8')  # 1 above the diagonal: position i cannot attend to j > i
19 |     subsequent_mask = torch.tensor(subsequent_mask).to(x.device)
20 |     subsequent_mask = subsequent_mask.unsqueeze(0).expand(batch_size, seq_len, seq_len)
21 |     return subsequent_mask
--------------------------------------------------------------------------------
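Note: a quick sanity check of the two masks on a toy batch. The output dtype depends on the PyTorch version (uint8 on 0.4.x, bool on modern releases); the values below assume `PAD_TOKEN_INDEX == 0` as defined above.

```
import torch
from utils.pad import pad_masking, subsequent_masking

x = torch.tensor([[5, 7, 0]])  # one sequence of length 3; the last position is padding

print(pad_masking(x, target_len=2))
# [[[0, 0, 1],
#   [0, 0, 1]]]        -> every query position masks the padded key position
print(subsequent_masking(x))
# [[[0, 1, 1],
#   [0, 0, 1],
#   [0, 0, 0]]]        -> position i may attend only to positions j <= i
```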
/utils/pipe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | PAD_INDEX = 0
4 | 
5 | 
6 | def input_target_collate_fn(batch):
7 |     """Merge a list of (source, input, target) samples into padded mini-batch tensors."""
8 | 
9 |     sources_lengths = [len(sources) for sources, inputs, targets in batch]
10 |     inputs_lengths = [len(inputs) for sources, inputs, targets in batch]
11 |     targets_lengths = [len(targets) for sources, inputs, targets in batch]
12 | 
13 |     sources_max_length = max(sources_lengths)
14 |     inputs_max_length = max(inputs_lengths)
15 |     targets_max_length = max(targets_lengths)
16 | 
17 |     sources_padded = [sources + [PAD_INDEX] * (sources_max_length - len(sources)) for sources, inputs, targets in batch]
18 |     inputs_padded = [inputs + [PAD_INDEX] * (inputs_max_length - len(inputs)) for sources, inputs, targets in batch]
19 |     targets_padded = [targets + [PAD_INDEX] * (targets_max_length - len(targets)) for sources, inputs, targets in batch]
20 | 
21 |     sources_tensor = torch.tensor(sources_padded)
22 |     inputs_tensor = torch.tensor(inputs_padded)
23 |     targets_tensor = torch.tensor(targets_padded)
24 | 
25 |     return sources_tensor, inputs_tensor, targets_tensor
26 | 
27 | 
28 | def shared_tokens_generator(dataset):
29 |     for source, target in dataset:
30 |         for token in source:
31 |             yield token
32 |         for token in target:
33 |             yield token
34 | 
35 | 
36 | def source_tokens_generator(dataset):
37 |     for source, target in dataset:
38 |         for token in source:
39 |             yield token
40 | 
41 | 
42 | def target_tokens_generator(dataset):
43 |     for source, target in dataset:
44 |         for token in target:
45 |             yield token
--------------------------------------------------------------------------------
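Note: a usage illustration of `input_target_collate_fn`; every field is padded to its batch-wise maximum length with `PAD_INDEX = 0` (the token indices below are made up).

```
import torch
from utils.pipe import input_target_collate_fn

batch = [
    # (source indices, decoder input indices, target indices)
    ([4, 9, 2], [1, 6], [6, 2]),
    ([7, 3], [1, 8, 5], [8, 5, 2]),
]

sources, inputs, targets = input_target_collate_fn(batch)
print(sources)        # tensor([[4, 9, 2], [7, 3, 0]])
print(inputs.shape)   # torch.Size([2, 3])
print(targets.shape)  # torch.Size([2, 3])
```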