├── .gitignore
├── LICENSE
├── README.md
├── beam.py
├── configs
│   └── example_config.json
├── data
│   └── example
│       └── raw
│           ├── src-test.txt
│           ├── src-train.txt
│           ├── src-val.txt
│           ├── tgt-train.txt
│           └── tgt-val.txt
├── datasets.py
├── dictionaries.py
├── embeddings.py
├── evaluate.py
├── evaluator.py
├── losses.py
├── metrics.py
├── models.py
├── optimizers.py
├── predict.py
├── predictors.py
├── prepare_datasets.py
├── train.py
├── trainer.py
└── utils
    ├── log.py
    ├── pad.py
    └── pipe.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/
2 | logs/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
109 | .idea/
110 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yongrae Jo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer-pytorch
2 | A PyTorch implementation of the Transformer from "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).
3 |
4 | This repo focuses on a clean, readable, and modular implementation of the paper.
5 |
6 |
7 |
8 | ## Requirements
9 | - Python 3.6+
10 | - [PyTorch 0.4.1+](http://pytorch.org/)
11 | - [NumPy](http://www.numpy.org/)
12 | - [NLTK](https://www.nltk.org/)
13 | - [tqdm](https://github.com/tqdm/tqdm)
14 |
15 | ## Usage
16 |
17 | ### Prepare datasets
18 | This repo comes with example data in the `data/` directory. To begin, prepare the datasets from this data as follows:
19 | ```
20 | $ python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed
21 | ```
22 |
23 | The example data is taken from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py).
24 | The data consists of parallel source (src) and target (tgt) data for training and validation.
25 | A data file contains one sentence per line, with tokens separated by a space (illustrated below).
26 | Below are the provided example data files.
27 |
28 | - `src-train.txt`
29 | - `tgt-train.txt`
30 | - `src-val.txt`
31 | - `tgt-val.txt`
32 |
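33 | For illustration, the parallel format looks like this (hypothetical lines, not taken from the actual files): line N of a `src` file and line N of the matching `tgt` file are translations of each other.
34 |
35 | ```
36 | src-train.txt:  It is a test .
37 | tgt-train.txt:  Es ist ein Test .
38 | ```
39 |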
33 | ### Train model
34 | To train a model, provide the train script with the path to the processed data, plus paths for saving the config, checkpoint, and log:
35 |
36 | ```
37 | $ python train.py --data_dir=data/example/processed --save_config=checkpoints/example_config.json --save_checkpoint=checkpoints/example_model.pth --save_log=logs/example.log
38 | ```
39 |
40 | This saves the model config and checkpoints to the given files.
41 | You can play around with the model's hyperparameters via command-line arguments.
42 | For example, add `--epochs=300` to set the number of epochs to 300.
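43 |
44 | You can also combine several architecture and training arguments in one run; a sketch (argument names are taken from `train.py`, and the values are illustrative):
45 | ```
46 | $ python train.py --data_dir=data/example/processed --layers_count=2 --heads_count=4 --d_model=256 --optimizer=Noam --epochs=300
47 | ```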
43 |
44 | ### Translate
45 | To translate a sentence from the source language to the target language:
46 | ```
47 | $ python predict.py --source="There is an imbalance here ." --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth
48 |
49 | Candidate 0 : Hier fehlt das Gleichgewicht .
50 | Candidate 1 : Hier fehlt das das Gleichgewicht .
51 | Candidate 2 : Hier fehlt das das das Gleichgewicht .
52 | ```
53 |
54 | It will print translation candidates for the given source sentence.
55 | You can adjust the number of candidates with the `--num_candidates` command-line argument.
56 |
57 | ### Evaluate
58 | To calculate the BLEU score of a trained model:
59 | ```
60 | $ python evaluate.py --save_result=logs/example_eval.txt --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth
61 |
62 | BLEU score : 0.0007947
63 | ```
64 |
65 | ## File description
66 | - `models.py` includes Transformer's encoder, decoder, and multi-head attention.
67 | - `embeddings.py` contains positional encoding.
68 | - `losses.py` contains label smoothing loss.
69 | - `optimizers.py` contains Noam optimizer.
70 | - `metrics.py` contains accuracy metric.
71 | - `beam.py` contains beam search.
72 | - `datasets.py` has code for loading and processing data.
73 | - `trainer.py` has code for training model.
74 | - `prepare_datasets.py` processes data.
75 | - `train.py` trains model.
76 | - `predict.py` translates given source sentence with a trained model.
77 | - `evaluate.py` calculates BLEU score of a trained model.
78 |
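79 | ## Quick start in Python
80 | A minimal sketch of programmatic usage, mirroring `predict.py`. The config and checkpoint paths below are the ones produced by the example train command above.
81 |
82 | ```
83 | import json
84 |
85 | from datasets import IndexedInputTargetTranslationDataset
86 | from dictionaries import IndexDictionary, END_TOKEN
87 | from models import build_model
88 | from predictors import Predictor
89 |
90 | with open('checkpoints/example_config.json') as f:
91 |     config = json.load(f)
92 |
93 | # Load the source/target vocabularies saved by prepare_datasets.py.
94 | source_dictionary = IndexDictionary.load(config['data_dir'], mode='source',
95 |                                          vocabulary_size=config['vocabulary_size'])
96 | target_dictionary = IndexDictionary.load(config['data_dir'], mode='target',
97 |                                          vocabulary_size=config['vocabulary_size'])
98 |
99 | model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)
100 |
101 | predictor = Predictor(
102 |     preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
103 |     postprocess=lambda x: ' '.join(token for token in target_dictionary.tokenify_indexes(x)
104 |                                    if token != END_TOKEN),
105 |     model=model,
106 |     checkpoint_filepath='checkpoints/example_model.pth')
107 |
108 | for index, candidate in enumerate(predictor.predict_one('There is an imbalance here .', num_candidates=3)):
109 |     print(f'Candidate {index} : {candidate}')
110 | ```
111 |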
79 | ## Reference
80 | - [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py)
81 |
82 | ## Author
83 | [@dreamgonfly](https://github.com/dreamgonfly)
--------------------------------------------------------------------------------
/beam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Beam:
5 |
6 | def __init__(self, beam_size=8, min_length=0, n_top=1, ranker=None,
7 | start_token_id=2, end_token_id=3):
8 | self.beam_size = beam_size
9 | self.min_length = min_length
10 | self.ranker = ranker
11 |
12 | self.end_token_id = end_token_id
13 | self.top_sentence_ended = False
14 |
15 | self.prev_ks = []
16 | self.next_ys = [torch.LongTensor(beam_size).fill_(start_token_id)] # remove padding
17 |
18 | self.current_scores = torch.FloatTensor(beam_size).zero_()
19 | self.all_scores = []
20 |
21 | # The attentions (matrix) for each time.
22 | self.all_attentions = []
23 |
24 |         # Time and k pairs for finished hypotheses.
25 |         self.finished = []
26 |         self.n_top = n_top
33 |
34 | def advance(self, next_log_probs, current_attention):
35 |         # next_log_probs : (beam_size, vocabulary_size)
36 | # current_attention: (target_seq_len=1, beam_size, source_seq_len)
37 |
38 | vocabulary_size = next_log_probs.size(1)
39 | # current_beam_size = next_log_probs.size(0)
40 |
41 | current_length = len(self.next_ys)
42 | if current_length < self.min_length:
43 | for beam_index in range(len(next_log_probs)):
44 | next_log_probs[beam_index][self.end_token_id] = -1e10
45 |
46 | if len(self.prev_ks) > 0:
47 | beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(next_log_probs)
48 | # Don't let EOS have children.
49 | last_y = self.next_ys[-1]
50 | for beam_index in range(last_y.size(0)):
51 | if last_y[beam_index] == self.end_token_id:
52 | beam_scores[beam_index] = -1e10 # -1e20 raises error when executing
53 | else:
54 | beam_scores = next_log_probs[0]
55 | flat_beam_scores = beam_scores.view(-1)
56 | top_scores, top_score_ids = flat_beam_scores.topk(k=self.beam_size, dim=0, largest=True, sorted=True)
57 |
58 | self.current_scores = top_scores
59 | self.all_scores.append(self.current_scores)
60 |
61 |         prev_k = top_score_ids // vocabulary_size  # (beam_size, ) floor division keeps the backpointers integral
62 | next_y = top_score_ids - prev_k * vocabulary_size # (beam_size, )
63 |
64 | self.prev_ks.append(prev_k)
65 | self.next_ys.append(next_y)
66 | # for RNN, dim=1 and for transformer, dim=0.
67 | prev_attention = current_attention.index_select(dim=0, index=prev_k) # (target_seq_len=1, beam_size, source_seq_len)
68 | self.all_attentions.append(prev_attention)
69 |
70 |
71 | for beam_index, last_token_id in enumerate(next_y):
72 | if last_token_id == self.end_token_id:
73 | # skip scoring
74 | self.finished.append((self.current_scores[beam_index], len(self.next_ys) - 1, beam_index))
75 |
76 | if next_y[0] == self.end_token_id:
77 | self.top_sentence_ended = True
78 |
79 | def get_current_state(self):
80 | "Get the outputs for the current timestep."
81 | return self.next_ys[-1]
82 |
83 | def get_current_origin(self):
84 | "Get the backpointers for the current timestep."
85 | return self.prev_ks[-1]
86 |
87 | def done(self):
88 | return self.top_sentence_ended and len(self.finished) >= self.n_top
89 |
90 | def get_hypothesis(self, timestep, k):
91 | hypothesis, attentions = [], []
92 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
93 | hypothesis.append(self.next_ys[j + 1][k])
94 |             # for RNN, [:, k, :], and for transformer, [k, :, :]
95 | attentions.append(self.all_attentions[j][k, :, :])
96 | k = self.prev_ks[j][k]
97 | attentions_tensor = torch.stack(attentions[::-1]).squeeze(1) # (timestep, source_seq_len)
98 | return hypothesis[::-1], attentions_tensor
99 |
100 | def sort_finished(self, minimum=None):
101 | if minimum is not None:
102 | i = 0
103 | # Add from beam until we have minimum outputs.
104 | while len(self.finished) < minimum:
105 | # global_scores = self.global_scorer.score(self, self.scores)
106 | # s = global_scores[i]
107 | s = self.current_scores[i]
108 | self.finished.append((s, len(self.next_ys) - 1, i))
109 | i += 1
110 |
111 | self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True)
112 | scores = [sc for sc, _, _ in self.finished]
113 | ks = [(t, k) for _, t, k in self.finished]
114 | return scores, ks
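115 |
116 | # Decoding note: get_hypothesis(timestep, k) reconstructs the k-th hypothesis by
117 | # following the prev_ks backpointers from (timestep, k) back to the start token,
118 | # then reversing the collected tokens and attentions.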
--------------------------------------------------------------------------------
/configs/example_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_limit": null,
3 | "print_every": 1,
4 | "save_every": 1,
5 |
6 | "vocabulary_size": null,
7 | "share_dictionary": false,
8 | "positional_encoding": true,
9 |
10 | "d_model": 128,
11 | "layers_count": 1,
12 | "heads_count": 2,
13 | "d_ff": 128,
14 | "dropout_prob": 0.1,
15 |
16 | "label_smoothing": 0.1,
17 | "optimizer": "Noam",
18 | "lr": 0.001,
19 | "clip_grads": true,
20 |
21 | "batch_size": 10,
22 | "epochs": 10
23 | }
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, abspath, join, exists
2 | from os import makedirs
3 | from dictionaries import START_TOKEN, END_TOKEN
4 | UNK_INDEX = 1
5 |
6 | BASE_DIR = dirname(abspath(__file__))
7 |
8 |
9 | class TranslationDatasetOnTheFly:
10 |
11 | def __init__(self, phase, limit=None):
12 | assert phase in ('train', 'val'), "Dataset phase must be either 'train' or 'val'"
13 |
14 | self.limit = limit
15 |
16 | if phase == 'train':
17 | source_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'src-train.txt')
18 | target_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'tgt-train.txt')
19 | elif phase == 'val':
20 | source_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'src-val.txt')
21 | target_filepath = join(BASE_DIR, 'data', 'example', 'raw', 'tgt-val.txt')
22 | else:
23 | raise NotImplementedError()
24 |
25 | with open(source_filepath) as source_file:
26 | self.source_data = source_file.readlines()
27 |
28 |         with open(target_filepath) as target_file:
29 |             self.target_data = target_file.readlines()
30 |
31 | def __getitem__(self, item):
32 | if self.limit is not None and item >= self.limit:
33 | raise IndexError()
34 |
35 | source = self.source_data[item].strip()
36 | target = self.target_data[item].strip()
37 | return source, target
38 |
39 | def __len__(self):
40 | if self.limit is None:
41 | return len(self.source_data)
42 | else:
43 | return self.limit
44 |
45 |
46 | class TranslationDataset:
47 |
48 | def __init__(self, data_dir, phase, limit=None):
49 | assert phase in ('train', 'val'), "Dataset phase must be either 'train' or 'val'"
50 |
51 | self.limit = limit
52 |
53 | self.data = []
54 | with open(join(data_dir, f'raw-{phase}.txt')) as file:
55 | for line in file:
56 | source, target = line.strip().split('\t')
57 | self.data.append((source, target))
58 |
59 | def __getitem__(self, item):
60 | if self.limit is not None and item >= self.limit:
61 | raise IndexError()
62 |
63 | return self.data[item]
64 |
65 | def __len__(self):
66 | if self.limit is None:
67 | return len(self.data)
68 | else:
69 | return self.limit
70 |
71 | @staticmethod
72 | def prepare(train_source, train_target, val_source, val_target, save_data_dir):
73 |
74 | if not exists(save_data_dir):
75 | makedirs(save_data_dir)
76 |
77 | for phase in ('train', 'val'):
78 |
79 | if phase == 'train':
80 | source_filepath = train_source
81 | target_filepath = train_target
82 | else:
83 | source_filepath = val_source
84 | target_filepath = val_target
85 |
86 | with open(source_filepath) as source_file:
87 | source_data = source_file.readlines()
88 |
89 |             with open(target_filepath) as target_file:
90 |                 target_data = target_file.readlines()
91 |
92 | with open(join(save_data_dir, f'raw-{phase}.txt'), 'w') as file:
93 | for source_line, target_line in zip(source_data, target_data):
94 | source_line = source_line.strip()
95 | target_line = target_line.strip()
96 | line = f'{source_line}\t{target_line}\n'
97 | file.write(line)
98 |
99 |
100 | class TokenizedTranslationDatasetOnTheFly:
101 |
102 | def __init__(self, phase, limit=None):
103 |
104 | self.raw_dataset = TranslationDatasetOnTheFly(phase, limit)
105 |
106 | def __getitem__(self, item):
107 | raw_source, raw_target = self.raw_dataset[item]
108 | tokenized_source = raw_source.split()
109 | tokenized_target = raw_target.split()
110 | return tokenized_source, tokenized_target
111 |
112 | def __len__(self):
113 | return len(self.raw_dataset)
114 |
115 |
116 | class TokenizedTranslationDataset:
117 |
118 | def __init__(self, data_dir, phase, limit=None):
119 |
120 | self.raw_dataset = TranslationDataset(data_dir, phase, limit)
121 |
122 | def __getitem__(self, item):
123 | raw_source, raw_target = self.raw_dataset[item]
124 | tokenized_source = raw_source.split()
125 | tokenized_target = raw_target.split()
126 | return tokenized_source, tokenized_target
127 |
128 | def __len__(self):
129 | return len(self.raw_dataset)
130 |
131 |
132 | class InputTargetTranslationDatasetOnTheFly:
133 |
134 | def __init__(self, phase, limit=None):
135 | self.tokenized_dataset = TokenizedTranslationDatasetOnTheFly(phase, limit)
136 |
137 | def __getitem__(self, item):
138 | tokenized_source, tokenized_target = self.tokenized_dataset[item]
139 | full_target = [START_TOKEN] + tokenized_target + [END_TOKEN]
140 | inputs = full_target[:-1]
141 | targets = full_target[1:]
142 | return tokenized_source, inputs, targets
143 |
144 | def __len__(self):
145 | return len(self.tokenized_dataset)
146 |
147 |
148 | class InputTargetTranslationDataset:
149 |
150 | def __init__(self, data_dir, phase, limit=None):
151 | self.tokenized_dataset = TokenizedTranslationDataset(data_dir, phase, limit)
152 |
153 | def __getitem__(self, item):
154 | tokenized_source, tokenized_target = self.tokenized_dataset[item]
155 | full_target = [START_TOKEN] + tokenized_target + [END_TOKEN]
156 | inputs = full_target[:-1]
157 | targets = full_target[1:]
158 | return tokenized_source, inputs, targets
159 |
160 | def __len__(self):
161 | return len(self.tokenized_dataset)
162 |
163 |
164 | class IndexedInputTargetTranslationDatasetOnTheFly:
165 |
166 | def __init__(self, phase, source_dictionary, target_dictionary, limit=None):
167 |
168 | self.input_target_dataset = InputTargetTranslationDatasetOnTheFly(phase, limit)
169 | self.source_dictionary = source_dictionary
170 | self.target_dictionary = target_dictionary
171 |
172 | def __getitem__(self, item):
173 | source, inputs, targets = self.input_target_dataset[item]
174 | indexed_source = self.source_dictionary.index_sentence(source)
175 | indexed_inputs = self.target_dictionary.index_sentence(inputs)
176 | indexed_targets = self.target_dictionary.index_sentence(targets)
177 |
178 | return indexed_source, indexed_inputs, indexed_targets
179 |
180 | def __len__(self):
181 | return len(self.input_target_dataset)
182 |
183 | @staticmethod
184 | def preprocess(source_dictionary):
185 |
186 | def preprocess_function(source):
187 | source_tokens = source.strip().split()
188 | indexed_source = source_dictionary.index_sentence(source_tokens)
189 | return indexed_source
190 |
191 | return preprocess_function
192 |
193 |
194 | class IndexedInputTargetTranslationDataset:
195 |
196 | def __init__(self, data_dir, phase, vocabulary_size=None, limit=None):
197 |
198 | self.data = []
199 |
200 | unknownify = lambda index: index if index < vocabulary_size else UNK_INDEX
201 | with open(join(data_dir, f'indexed-{phase}.txt')) as file:
202 | for line in file:
203 | sources, inputs, targets = line.strip().split('\t')
204 | if vocabulary_size is not None:
205 | indexed_sources = [unknownify(int(index)) for index in sources.strip().split(' ')]
206 | indexed_inputs = [unknownify(int(index)) for index in inputs.strip().split(' ')]
207 | indexed_targets = [unknownify(int(index)) for index in targets.strip().split(' ')]
208 | else:
209 | indexed_sources = [int(index) for index in sources.strip().split(' ')]
210 | indexed_inputs = [int(index) for index in inputs.strip().split(' ')]
211 | indexed_targets = [int(index) for index in targets.strip().split(' ')]
212 | self.data.append((indexed_sources, indexed_inputs, indexed_targets))
213 | if limit is not None and len(self.data) >= limit:
214 | break
215 |
216 | self.vocabulary_size = vocabulary_size
217 | self.limit = limit
218 |
219 | def __getitem__(self, item):
220 | if self.limit is not None and item >= self.limit:
221 | raise IndexError()
222 |
223 | indexed_sources, indexed_inputs, indexed_targets = self.data[item]
224 | return indexed_sources, indexed_inputs, indexed_targets
225 |
226 | def __len__(self):
227 | if self.limit is None:
228 | return len(self.data)
229 | else:
230 | return self.limit
231 |
232 | @staticmethod
233 | def preprocess(source_dictionary):
234 |
235 | def preprocess_function(source):
236 | source_tokens = source.strip().split()
237 | indexed_source = source_dictionary.index_sentence(source_tokens)
238 | return indexed_source
239 |
240 | return preprocess_function
241 |
242 | @staticmethod
243 | def prepare(data_dir, source_dictionary, target_dictionary):
244 |
245 | join_indexes = lambda indexes: ' '.join(str(index) for index in indexes)
246 | for phase in ('train', 'val'):
247 | input_target_dataset = InputTargetTranslationDataset(data_dir, phase)
248 |
249 | with open(join(data_dir, f'indexed-{phase}.txt'), 'w') as file:
250 | for sources, inputs, targets in input_target_dataset:
251 | indexed_sources = join_indexes(source_dictionary.index_sentence(sources))
252 | indexed_inputs = join_indexes(target_dictionary.index_sentence(inputs))
253 | indexed_targets = join_indexes(target_dictionary.index_sentence(targets))
254 | file.write(f'{indexed_sources}\t{indexed_inputs}\t{indexed_targets}\n')
255 |
--------------------------------------------------------------------------------
/dictionaries.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from os.path import dirname, abspath, join, exists
3 | from os import makedirs
4 |
5 | BASE_DIR = dirname(abspath(__file__))
6 |
7 | PAD_TOKEN = '<pad>'
8 | UNK_TOKEN = '<unk>'
9 | START_TOKEN = '<s>'
10 | END_TOKEN = '</s>'  # special-token order fixes the indexes: pad=0, unk=1, start=2, end=3
11 |
12 |
13 | class IndexDictionary:
14 |
15 | def __init__(self, iterable=None, mode='shared', vocabulary_size=None):
16 |
17 | self.special_tokens = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN]
18 |
19 | # On-the-fly mode
20 | if iterable is not None:
21 |
22 | self.vocab_tokens, self.token_counts = self._build_vocabulary(iterable, vocabulary_size)
23 | self.token_index_dict = {token: index for index, token in enumerate(self.vocab_tokens)}
24 | self.vocabulary_size = len(self.vocab_tokens)
25 |
26 | self.mode = mode
27 |
28 | def token_to_index(self, token):
29 | try:
30 | return self.token_index_dict[token]
31 | except KeyError:
32 | return self.token_index_dict[UNK_TOKEN]
33 |
34 | def index_to_token(self, index):
35 | if index >= self.vocabulary_size:
36 |             return UNK_TOKEN
37 | else:
38 | return self.vocab_tokens[index]
39 |
40 | def index_sentence(self, sentence):
41 | return [self.token_to_index(token) for token in sentence]
42 |
43 | def tokenify_indexes(self, token_indexes):
44 | return [self.index_to_token(token_index) for token_index in token_indexes]
45 |
46 | def _build_vocabulary(self, iterable, vocabulary_size):
47 |
48 | counter = Counter()
49 | for token in iterable:
50 | counter[token] += 1
51 |
52 | if vocabulary_size is not None:
53 | most_commons = counter.most_common(vocabulary_size - len(self.special_tokens))
54 | frequent_tokens = [token for token, count in most_commons]
55 | vocab_tokens = self.special_tokens + frequent_tokens
56 | token_counts = [0] * len(self.special_tokens) + [count for token, count in most_commons]
57 | else:
58 | all_tokens = [token for token, count in counter.items()]
59 | vocab_tokens = self.special_tokens + all_tokens
60 | token_counts = [0] * len(self.special_tokens) + [count for token, count in counter.items()]
61 |
62 | return vocab_tokens, token_counts
63 |
64 | def save(self, data_dir):
65 |
66 | vocabulary_filepath = join(data_dir, f'vocabulary-{self.mode}.txt')
67 | with open(vocabulary_filepath, 'w') as file:
68 | for vocab_index, (vocab_token, count) in enumerate(zip(self.vocab_tokens, self.token_counts)):
69 | file.write(str(vocab_index) + '\t' + vocab_token + '\t' + str(count) + '\n')
70 |
71 | @classmethod
72 | def load(cls, data_dir, mode='shared', vocabulary_size=None):
73 | vocabulary_filepath = join(data_dir, f'vocabulary-{mode}.txt')
74 |
75 | vocab_tokens = {}
76 | token_counts = []
77 | with open(vocabulary_filepath) as file:
78 | for line in file:
79 | vocab_index, vocab_token, count = line.strip().split('\t')
80 | vocab_index = int(vocab_index)
81 | vocab_tokens[vocab_index] = vocab_token
82 | token_counts.append(int(count))
83 |
84 | if vocabulary_size is not None:
85 | vocab_tokens = {k: v for k, v in vocab_tokens.items() if k < vocabulary_size}
86 | token_counts = token_counts[:vocabulary_size]
87 |
88 | instance = cls(mode=mode)
89 | instance.vocab_tokens = vocab_tokens
90 | instance.token_counts = token_counts
91 | instance.token_index_dict = {token: index for index, token in vocab_tokens.items()}
92 | instance.vocabulary_size = len(vocab_tokens)
93 |
94 | return instance
95 |
96 |
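97 | # Vocabulary file format (written by save() and read by load()): one entry per
98 | # line as "index<TAB>token<TAB>count", with the four special tokens at indexes
99 | # 0-3 carrying a count of 0.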
--------------------------------------------------------------------------------
/embeddings.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 |
6 | class PositionalEncoding(nn.Module):
7 | """
8 | Implements the sinusoidal positional encoding for
9 | non-recurrent neural networks.
10 |
11 | Implementation based on "Attention Is All You Need"
12 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
13 |
14 | Args:
15 | dropout_prob (float): dropout parameter
16 | dim (int): embedding size
17 | """
18 |
19 | def __init__(self, num_embeddings, embedding_dim, dim, dropout_prob=0., padding_idx=0, max_len=5000):
20 | super(PositionalEncoding, self).__init__()
21 |
22 | pe = torch.zeros(max_len, dim)
23 | position = torch.arange(0, max_len).unsqueeze(1)
24 | div_term = torch.exp((torch.arange(0, dim, 2) *
25 | -(math.log(10000.0) / dim)).float())
26 | pe[:, 0::2] = torch.sin(position.float() * div_term)
27 | pe[:, 1::2] = torch.cos(position.float() * div_term)
28 | pe = pe.unsqueeze(0)
29 |
30 | self.num_embeddings = num_embeddings
31 | self.embedding_dim = embedding_dim
32 |         self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
33 |         self.weight = self.embedding.weight
34 | self.register_buffer('pe', pe)
35 | self.dropout = nn.Dropout(p=dropout_prob)
36 | self.dim = dim
37 |
38 | def forward(self, x, step=None):
39 |         x = self.embedding(x)
40 | x = x * math.sqrt(self.dim)
41 | if step is None:
42 | x = x + self.pe[:, :x.size(1)]
43 | else:
44 | x = x + self.pe[:, step]
45 | x = self.dropout(x)
46 | return x
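47 |
48 | # The 'pe' buffer holds the sinusoids from the paper, computed above in log space:
49 | #   PE(pos, 2i)   = sin(pos / 10000^(2i / dim))
50 | #   PE(pos, 2i+1) = cos(pos / 10000^(2i / dim))
51 | # The token embedding is scaled by sqrt(dim) before the encoding is added.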
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | from evaluator import Evaluator
2 | from predictors import Predictor
3 | from models import build_model
4 | from datasets import TranslationDataset
5 | from datasets import IndexedInputTargetTranslationDataset
6 | from dictionaries import IndexDictionary, END_TOKEN
7 |
8 | from argparse import ArgumentParser
9 | import json
10 | from datetime import datetime
11 |
12 | parser = ArgumentParser(description='Evaluate translation')
13 | parser.add_argument('--save_result', type=str, default=None)
14 | parser.add_argument('--config', type=str, required=True)
15 | parser.add_argument('--checkpoint', type=str, required=True)
16 | parser.add_argument('--phase', type=str, default='val', choices=['train', 'val'])
17 |
18 | args = parser.parse_args()
19 | with open(args.config) as f:
20 | config = json.load(f)
21 |
22 | print('Constructing dictionaries...')
23 | source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
24 | target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])
25 |
26 | print('Building model...')
27 | model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)
28 |
29 | predictor = Predictor(
30 | preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
31 |     postprocess=lambda x: ' '.join([token for token in target_dictionary.tokenify_indexes(x) if token != END_TOKEN]),
32 | model=model,
33 | checkpoint_filepath=args.checkpoint
34 | )
35 |
36 | timestamp = datetime.now()
37 | if args.save_result is None:
38 | eval_filepath = 'logs/eval-{config}-time={timestamp}.csv'.format(
39 | config=args.config.replace('/', '-'),
40 | timestamp=timestamp.strftime("%Y_%m_%d_%H_%M_%S"))
41 | else:
42 | eval_filepath = args.save_result
43 |
44 | evaluator = Evaluator(
45 | predictor=predictor,
46 | save_filepath=eval_filepath
47 | )
48 |
49 | print('Evaluating...')
50 | test_dataset = TranslationDataset(config['data_dir'], args.phase, limit=1000)
51 | bleu_score = evaluator.evaluate_dataset(test_dataset)
52 | print('Evaluation time :', datetime.now() - timestamp)
53 |
54 | print("BLEU score :", bleu_score)
55 |
56 |
57 |
--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
1 | from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
2 | from tqdm import tqdm
3 |
4 |
5 | class Evaluator:
6 |
7 | def __init__(self, predictor, save_filepath):
8 |
9 | self.predictor = predictor
10 | self.save_filepath = save_filepath
11 |
12 | def evaluate_dataset(self, test_dataset):
13 | tokenize = lambda x: x.split()
14 |
15 | predictions = []
16 | for source, target in tqdm(test_dataset):
17 | prediction = self.predictor.predict_one(source, num_candidates=1)[0]
18 | predictions.append(prediction)
19 |
20 | hypotheses = [tokenize(prediction) for prediction in predictions]
21 | list_of_references = [[tokenize(target)] for source, target in test_dataset]
22 | smoothing_function = SmoothingFunction()
23 |
24 | with open(self.save_filepath, 'w') as file:
25 | for (source, target), prediction, hypothesis, references in zip(test_dataset, predictions,
26 | hypotheses, list_of_references):
27 | sentence_bleu_score = sentence_bleu(references, hypothesis,
28 | smoothing_function=smoothing_function.method3)
29 | line = "{bleu_score}\t{source}\t{target}\t|\t{prediction}".format(
30 | bleu_score=sentence_bleu_score,
31 | source=source,
32 | target=target,
33 | prediction=prediction
34 | )
35 | file.write(line + '\n')
36 |
37 | bleu_score = corpus_bleu(list_of_references, hypotheses, smoothing_function=smoothing_function.method3)
38 |
39 | return bleu_score
40 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class TokenCrossEntropyLoss(nn.Module):
6 |
7 | def __init__(self, pad_index=0):
8 | super(TokenCrossEntropyLoss, self).__init__()
9 |
10 | self.pad_index = pad_index
11 | self.base_loss_function = nn.CrossEntropyLoss(reduction='sum', ignore_index=pad_index)
12 |
13 | def forward(self, outputs, targets):
14 | batch_size, seq_len, vocabulary_size = outputs.size()
15 |
16 | outputs_flat = outputs.view(batch_size * seq_len, vocabulary_size)
17 | targets_flat = targets.view(batch_size * seq_len)
18 |
19 | batch_loss = self.base_loss_function(outputs_flat, targets_flat)
20 |
21 | count = (targets != self.pad_index).sum().item()
22 |
23 | return batch_loss, count
24 |
25 |
26 | class LabelSmoothingLoss(nn.Module):
27 | """
28 | With label smoothing,
29 | KL-divergence between q_{smoothed ground truth prob.}(w)
30 | and p_{prob. computed by model}(w) is minimized.
31 | """
32 | def __init__(self, label_smoothing, vocabulary_size, pad_index=0):
33 | assert 0.0 < label_smoothing <= 1.0
34 |
35 | super(LabelSmoothingLoss, self).__init__()
36 |
37 | self.pad_index = pad_index
38 | self.log_softmax = nn.LogSoftmax(dim=-1)
39 | self.criterion = nn.KLDivLoss(reduction='sum')
40 |
41 | smoothing_value = label_smoothing / (vocabulary_size - 2) # exclude pad and true label
42 | smoothed_targets = torch.full((vocabulary_size,), smoothing_value)
43 | smoothed_targets[self.pad_index] = 0
44 | self.register_buffer('smoothed_targets', smoothed_targets.unsqueeze(0)) # (1, vocabulary_size)
45 |
46 | self.confidence = 1.0 - label_smoothing
47 |
48 | def forward(self, outputs, targets):
49 | """
50 | outputs (FloatTensor): (batch_size, seq_len, vocabulary_size)
51 | targets (LongTensor): (batch_size, seq_len)
52 | """
53 | batch_size, seq_len, vocabulary_size = outputs.size()
54 |
55 | outputs_log_softmax = self.log_softmax(outputs)
56 | outputs_flat = outputs_log_softmax.view(batch_size * seq_len, vocabulary_size)
57 | targets_flat = targets.view(batch_size * seq_len)
58 |
59 | smoothed_targets = self.smoothed_targets.repeat(targets_flat.size(0), 1)
60 | # smoothed_targets: (batch_size * seq_len, vocabulary_size)
61 |
62 | smoothed_targets.scatter_(1, targets_flat.unsqueeze(1), self.confidence)
63 | # smoothed_targets: (batch_size * seq_len, vocabulary_size)
64 |
65 | smoothed_targets.masked_fill_((targets_flat == self.pad_index).unsqueeze(1), 0)
66 | # masked_targets: (batch_size * seq_len, vocabulary_size)
67 |
68 | loss = self.criterion(outputs_flat, smoothed_targets)
69 | count = (targets != self.pad_index).sum().item()
70 |
71 | return loss, count
72 |
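73 | # Worked example (illustrative, not produced by this code): with vocabulary_size=5,
74 | # label_smoothing=0.1 and a target index of 3, the smoothed target row becomes
75 | # [0.0, 0.0333, 0.0333, 0.9, 0.0333]: the 0.1 smoothing mass is spread over
76 | # vocabulary_size - 2 classes, the true class keeps confidence 0.9, and the pad
77 | # index 0 is zeroed.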
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class AccuracyMetric(nn.Module):
5 |
6 | def __init__(self, pad_index=0):
7 | super(AccuracyMetric, self).__init__()
8 |
9 | self.pad_index = pad_index
10 |
11 | def forward(self, outputs, targets):
12 |
13 | batch_size, seq_len, vocabulary_size = outputs.size()
14 |
15 | outputs = outputs.view(batch_size * seq_len, vocabulary_size)
16 | targets = targets.view(batch_size * seq_len)
17 |
18 | predicts = outputs.argmax(dim=1)
19 | corrects = predicts == targets
20 |
21 | corrects.masked_fill_((targets == self.pad_index), 0)
22 |
23 | correct_count = corrects.sum().item()
24 | count = (targets != self.pad_index).sum().item()
25 |
26 | return correct_count, count
27 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from embeddings import PositionalEncoding
2 | from utils.pad import pad_masking, subsequent_masking
3 |
4 | import torch
5 | from torch import nn
6 | import numpy as np
7 | from collections import defaultdict
8 |
9 | PAD_TOKEN_ID = 0
10 |
11 |
12 | def build_model(config, source_vocabulary_size, target_vocabulary_size):
13 | if config['positional_encoding']:
14 | source_embedding = PositionalEncoding(
15 | num_embeddings=source_vocabulary_size,
16 | embedding_dim=config['d_model'],
17 |             dim=config['d_model'])  # dim is the positional-encoding size; it must equal d_model here
18 | target_embedding = PositionalEncoding(
19 | num_embeddings=target_vocabulary_size,
20 | embedding_dim=config['d_model'],
21 |             dim=config['d_model'])  # same as above: positional-encoding size equals d_model
22 | else:
23 | source_embedding = nn.Embedding(
24 | num_embeddings=source_vocabulary_size,
25 | embedding_dim=config['d_model'])
26 | target_embedding = nn.Embedding(
27 | num_embeddings=target_vocabulary_size,
28 | embedding_dim=config['d_model'])
29 |
30 | encoder = TransformerEncoder(
31 | layers_count=config['layers_count'],
32 | d_model=config['d_model'],
33 | heads_count=config['heads_count'],
34 | d_ff=config['d_ff'],
35 | dropout_prob=config['dropout_prob'],
36 | embedding=source_embedding)
37 |
38 | decoder = TransformerDecoder(
39 | layers_count=config['layers_count'],
40 | d_model=config['d_model'],
41 | heads_count=config['heads_count'],
42 | d_ff=config['d_ff'],
43 | dropout_prob=config['dropout_prob'],
44 | embedding=target_embedding)
45 |
46 | model = Transformer(encoder, decoder)
47 |
48 | return model
49 |
50 |
51 | class Transformer(nn.Module):
52 |
53 | def __init__(self, encoder, decoder):
54 | super(Transformer, self).__init__()
55 |
56 | self.encoder = encoder
57 | self.decoder = decoder
58 |
59 | def forward(self, sources, inputs):
60 | # sources : (batch_size, sources_len)
61 | # inputs : (batch_size, targets_len - 1)
62 | batch_size, sources_len = sources.size()
63 | batch_size, inputs_len = inputs.size()
64 |
65 | sources_mask = pad_masking(sources, sources_len)
66 | memory_mask = pad_masking(sources, inputs_len)
67 | inputs_mask = subsequent_masking(inputs) | pad_masking(inputs, inputs_len)
68 |
69 | memory = self.encoder(sources, sources_mask) # (batch_size, seq_len, d_model)
70 | outputs, state = self.decoder(inputs, memory, memory_mask, inputs_mask) # (batch_size, seq_len, d_model)
71 | return outputs
72 |
73 |
74 | class TransformerEncoder(nn.Module):
75 |
76 | def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
77 | super(TransformerEncoder, self).__init__()
78 |
79 | self.d_model = d_model
80 | self.embedding = embedding
81 | self.encoder_layers = nn.ModuleList(
82 | [TransformerEncoderLayer(d_model, heads_count, d_ff, dropout_prob) for _ in range(layers_count)]
83 | )
84 |
85 | def forward(self, sources, mask):
86 | """
87 |
88 | args:
89 | sources: embedded_sequence, (batch_size, seq_len, embed_size)
90 | """
91 | sources = self.embedding(sources)
92 |
93 | for encoder_layer in self.encoder_layers:
94 | sources = encoder_layer(sources, mask)
95 |
96 | return sources
97 |
98 |
99 | class TransformerEncoderLayer(nn.Module):
100 |
101 | def __init__(self, d_model, heads_count, d_ff, dropout_prob):
102 | super(TransformerEncoderLayer, self).__init__()
103 |
104 | self.self_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob), d_model)
105 | self.pointwise_feedforward_layer = Sublayer(PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)
106 | self.dropout = nn.Dropout(dropout_prob)
107 |
108 | def forward(self, sources, sources_mask):
109 | # x: (batch_size, seq_len, d_model)
110 |
111 | sources = self.self_attention_layer(sources, sources, sources, sources_mask)
112 | sources = self.dropout(sources)
113 | sources = self.pointwise_feedforward_layer(sources)
114 |
115 | return sources
116 |
117 |
118 | class TransformerDecoder(nn.Module):
119 |
120 | def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
121 | super(TransformerDecoder, self).__init__()
122 |
123 | self.d_model = d_model
124 | self.embedding = embedding
125 | self.decoder_layers = nn.ModuleList(
126 | [TransformerDecoderLayer(d_model, heads_count, d_ff, dropout_prob) for _ in range(layers_count)]
127 | )
128 | self.generator = nn.Linear(embedding.embedding_dim, embedding.num_embeddings)
129 | self.generator.weight = self.embedding.weight
130 |
131 | def forward(self, inputs, memory, memory_mask, inputs_mask=None, state=None):
132 | # inputs: (batch_size, seq_len - 1, d_model)
133 | # memory: (batch_size, seq_len, d_model)
134 |
135 | inputs = self.embedding(inputs)
136 | # if state is not None:
137 | # inputs = torch.cat([state.previous_inputs, inputs], dim=1)
138 | #
139 | # state.previous_inputs = inputs
140 |
141 | for layer_index, decoder_layer in enumerate(self.decoder_layers):
142 | if state is None:
143 | inputs = decoder_layer(inputs, memory, memory_mask, inputs_mask)
144 | else: # Use cache
145 | layer_cache = state.layer_caches[layer_index]
146 | # print('inputs_mask', inputs_mask)
147 | inputs = decoder_layer(inputs, memory, memory_mask, inputs_mask, layer_cache)
148 |
149 | state.update_state(
150 | layer_index=layer_index,
151 | layer_mode='self-attention',
152 | key_projected=decoder_layer.self_attention_layer.sublayer.key_projected,
153 | value_projected=decoder_layer.self_attention_layer.sublayer.value_projected,
154 | )
155 | state.update_state(
156 | layer_index=layer_index,
157 | layer_mode='memory-attention',
158 | key_projected=decoder_layer.memory_attention_layer.sublayer.key_projected,
159 | value_projected=decoder_layer.memory_attention_layer.sublayer.value_projected,
160 | )
161 |
162 | generated = self.generator(inputs) # (batch_size, seq_len, vocab_size)
163 | return generated, state
164 |
165 | def init_decoder_state(self, **args):
166 | return DecoderState()
167 |
168 |
169 | class TransformerDecoderLayer(nn.Module):
170 |
171 | def __init__(self, d_model, heads_count, d_ff, dropout_prob):
172 | super(TransformerDecoderLayer, self).__init__()
173 | self.self_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob, mode='self-attention'), d_model)
174 | self.memory_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob, mode='memory-attention'), d_model)
175 | self.pointwise_feedforward_layer = Sublayer(PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)
176 |
177 | def forward(self, inputs, memory, memory_mask, inputs_mask, layer_cache=None):
178 | # print('self attention')
179 | # print('inputs_mask', inputs_mask)
180 | inputs = self.self_attention_layer(inputs, inputs, inputs, inputs_mask, layer_cache)
181 | # print('memory attention')
182 | inputs = self.memory_attention_layer(inputs, memory, memory, memory_mask, layer_cache)
183 | inputs = self.pointwise_feedforward_layer(inputs)
184 | return inputs
185 |
186 |
187 | class Sublayer(nn.Module):
188 |
189 | def __init__(self, sublayer, d_model):
190 | super(Sublayer, self).__init__()
191 |
192 | self.sublayer = sublayer
193 | self.layer_normalization = LayerNormalization(d_model)
194 |
195 | def forward(self, *args):
196 | x = args[0]
197 | x = self.sublayer(*args) + x
198 | return self.layer_normalization(x)
199 |
200 |
201 | class LayerNormalization(nn.Module):
202 |
203 | def __init__(self, features_count, epsilon=1e-6):
204 | super(LayerNormalization, self).__init__()
205 |
206 | self.gain = nn.Parameter(torch.ones(features_count))
207 | self.bias = nn.Parameter(torch.zeros(features_count))
208 | self.epsilon = epsilon
209 |
210 | def forward(self, x):
211 |
212 | mean = x.mean(dim=-1, keepdim=True)
213 | std = x.std(dim=-1, keepdim=True)
214 |
215 | return self.gain * (x - mean) / (std + self.epsilon) + self.bias
216 |
217 |
218 | class MultiHeadAttention(nn.Module):
219 |
220 | def __init__(self, heads_count, d_model, dropout_prob, mode='self-attention'):
221 | super(MultiHeadAttention, self).__init__()
222 |
223 | assert d_model % heads_count == 0
224 | assert mode in ('self-attention', 'memory-attention')
225 |
226 | self.d_head = d_model // heads_count
227 | self.heads_count = heads_count
228 | self.mode = mode
229 | self.query_projection = nn.Linear(d_model, heads_count * self.d_head)
230 | self.key_projection = nn.Linear(d_model, heads_count * self.d_head)
231 | self.value_projection = nn.Linear(d_model, heads_count * self.d_head)
232 | self.final_projection = nn.Linear(d_model, heads_count * self.d_head)
233 | self.dropout = nn.Dropout(dropout_prob)
234 | self.softmax = nn.Softmax(dim=3)
235 |
236 | self.attention = None
237 | # For cache
238 | self.key_projected = None
239 | self.value_projected = None
240 |
241 | def forward(self, query, key, value, mask=None, layer_cache=None):
242 | """
243 |
244 | Args:
245 | query: (batch_size, query_len, model_dim)
246 | key: (batch_size, key_len, model_dim)
247 | value: (batch_size, value_len, model_dim)
248 | mask: (batch_size, query_len, key_len)
249 | state: DecoderState
250 | """
251 | # print('attention mask', mask)
252 | batch_size, query_len, d_model = query.size()
253 |
254 | d_head = d_model // self.heads_count
255 |
256 | query_projected = self.query_projection(query)
257 | # print('query_projected', query_projected.shape)
258 | if layer_cache is None or layer_cache[self.mode] is None: # Don't use cache
259 | key_projected = self.key_projection(key)
260 | value_projected = self.value_projection(value)
261 | else: # Use cache
262 | if self.mode == 'self-attention':
263 | key_projected = self.key_projection(key)
264 | value_projected = self.value_projection(value)
265 |
266 | key_projected = torch.cat([key_projected, layer_cache[self.mode]['key_projected']], dim=1)
267 | value_projected = torch.cat([value_projected, layer_cache[self.mode]['value_projected']], dim=1)
268 | elif self.mode == 'memory-attention':
269 | key_projected = layer_cache[self.mode]['key_projected']
270 | value_projected = layer_cache[self.mode]['value_projected']
271 |
272 | # For cache
273 | self.key_projected = key_projected
274 | self.value_projected = value_projected
275 |
276 | batch_size, key_len, d_model = key_projected.size()
277 | batch_size, value_len, d_model = value_projected.size()
278 |
279 | query_heads = query_projected.view(batch_size, query_len, self.heads_count, d_head).transpose(1, 2) # (batch_size, heads_count, query_len, d_head)
280 | # print('query_heads', query_heads.shape)
281 | # print(batch_size, key_len, self.heads_count, d_head)
282 | # print(key_projected.shape)
283 | key_heads = key_projected.view(batch_size, key_len, self.heads_count, d_head).transpose(1, 2) # (batch_size, heads_count, key_len, d_head)
284 | value_heads = value_projected.view(batch_size, value_len, self.heads_count, d_head).transpose(1, 2) # (batch_size, heads_count, value_len, d_head)
285 |
286 | attention_weights = self.scaled_dot_product(query_heads, key_heads) # (batch_size, heads_count, query_len, key_len)
287 |
288 | if mask is not None:
289 | # print('mode', self.mode)
290 | # print('mask', mask.shape)
291 | # print('attention_weights', attention_weights.shape)
292 | mask_expanded = mask.unsqueeze(1).expand_as(attention_weights)
293 | attention_weights = attention_weights.masked_fill(mask_expanded, -1e18)
294 |
295 | self.attention = self.softmax(attention_weights) # Save attention to the object
296 | # print('attention_weights', attention_weights.shape)
297 | attention_dropped = self.dropout(self.attention)
298 | context_heads = torch.matmul(attention_dropped, value_heads) # (batch_size, heads_count, query_len, d_head)
299 | # print('context_heads', context_heads.shape)
300 | context_sequence = context_heads.transpose(1, 2).contiguous() # (batch_size, query_len, heads_count, d_head)
301 | context = context_sequence.view(batch_size, query_len, d_model) # (batch_size, query_len, d_model)
302 | final_output = self.final_projection(context)
303 | # print('final_output', final_output.shape)
304 |
305 | return final_output
306 |
307 | def scaled_dot_product(self, query_heads, key_heads):
308 | """
309 |
310 | Args:
311 | query_heads: (batch_size, heads_count, query_len, d_head)
312 | key_heads: (batch_size, heads_count, key_len, d_head)
313 | """
314 | key_heads_transposed = key_heads.transpose(2, 3)
315 | dot_product = torch.matmul(query_heads, key_heads_transposed) # (batch_size, heads_count, query_len, key_len)
316 | attention_weights = dot_product / np.sqrt(self.d_head)
317 | return attention_weights
318 |
319 |
320 | class PointwiseFeedForwardNetwork(nn.Module):
321 |
322 | def __init__(self, d_ff, d_model, dropout_prob):
323 | super(PointwiseFeedForwardNetwork, self).__init__()
324 |
325 | self.feed_forward = nn.Sequential(
326 | nn.Linear(d_model, d_ff),
327 | nn.Dropout(dropout_prob),
328 | nn.ReLU(),
329 | nn.Linear(d_ff, d_model),
330 | nn.Dropout(dropout_prob),
331 | )
332 |
333 | def forward(self, x):
334 | """
335 |
336 | Args:
337 | x: (batch_size, seq_len, d_model)
338 | """
339 | return self.feed_forward(x)
340 |
341 |
342 | class DecoderState:
343 |
344 | def __init__(self):
345 | self.previous_inputs = torch.tensor([])
346 | self.layer_caches = defaultdict(lambda: {'self-attention': None, 'memory-attention': None})
347 |
348 | def update_state(self, layer_index, layer_mode, key_projected, value_projected):
349 | self.layer_caches[layer_index][layer_mode] = {
350 | 'key_projected': key_projected,
351 | 'value_projected': value_projected
352 | }
353 |
354 |     # def repeat_beam_size_times(self, beam_size):  # Only the memory would need repeating, but we chose not to keep memory in the state.
355 | # self.
356 | # self.src = self.src.data.repeat(beam_size, 1)
357 |
358 | def beam_update(self, positions):
359 | for layer_index in self.layer_caches:
360 | for mode in ('self-attention', 'memory-attention'):
361 | if self.layer_caches[layer_index][mode] is not None:
362 | for projection in self.layer_caches[layer_index][mode]:
363 | cache = self.layer_caches[layer_index][mode][projection]
364 | if cache is not None:
365 | cache.data.copy_(cache.data.index_select(0, positions))
366 |
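367 | # Recap: scaled_dot_product computes the attention logits Q @ K^T / sqrt(d_head)
368 | # per head, masked positions are filled with -1e18 before the softmax, and
369 | # DecoderState caches the key/value projections of each layer so that beam
370 | # search can reorder them with beam_update(positions).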
--------------------------------------------------------------------------------
/optimizers.py:
--------------------------------------------------------------------------------
1 | from torch.optim import Adam
2 |
3 |
4 | class NoamOptimizer(Adam):
5 |
6 | def __init__(self, params, d_model, factor=2, warmup_steps=4000, betas=(0.9, 0.98), eps=1e-9):
7 | # self.optimizer = Adam(params, betas=betas, eps=eps)
8 | self.d_model = d_model
9 | self.warmup_steps = warmup_steps
10 | self.lr = 0
11 | self.step_num = 0
12 | self.factor = factor
13 |
14 | super(NoamOptimizer, self).__init__(params, betas=betas, eps=eps)
15 |
16 | def step(self, closure=None):
17 | self.step_num += 1
18 | self.lr = self.lrate()
19 | for group in self.param_groups:
20 | group['lr'] = self.lr
21 | super(NoamOptimizer, self).step()
22 |
23 | def lrate(self):
24 | return self.factor * self.d_model ** (-0.5) * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
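25 |
26 | # lrate implements the learning-rate schedule from "Attention Is All You Need",
27 | # scaled by `factor`:
28 | #   lr = factor * d_model^(-0.5) * min(step_num^(-0.5), step_num * warmup_steps^(-1.5))
29 | # i.e. linear warmup for warmup_steps steps, then inverse-square-root decay.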
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from predictors import Predictor
2 | from models import build_model
3 | from datasets import IndexedInputTargetTranslationDataset
4 | from dictionaries import IndexDictionary, END_TOKEN
5 |
6 | from argparse import ArgumentParser
7 | import json
8 |
9 | parser = ArgumentParser(description='Predict translation')
10 | parser.add_argument('--source', type=str)
11 | parser.add_argument('--config', type=str, required=True)
12 | parser.add_argument('--checkpoint', type=str)
13 | parser.add_argument('--num_candidates', type=int, default=3)
14 |
15 | args = parser.parse_args()
16 | with open(args.config) as f:
17 | config = json.load(f)
18 |
19 | print('Constructing dictionaries...')
20 | source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
21 | target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])
22 |
23 | print('Building model...')
24 | model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)
25 |
26 | predictor = Predictor(
27 | preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
28 |     postprocess=lambda x: ' '.join([token for token in target_dictionary.tokenify_indexes(x) if token != END_TOKEN]),
29 | model=model,
30 | checkpoint_filepath=args.checkpoint
31 | )
32 |
33 | for index, candidate in enumerate(predictor.predict_one(args.source, num_candidates=args.num_candidates)):
34 | print(f'Candidate {index} : {candidate}')
35 |
--------------------------------------------------------------------------------
/predictors.py:
--------------------------------------------------------------------------------
1 | from beam import Beam
2 | from utils.pad import pad_masking
3 |
4 | import torch
5 |
6 |
7 | class Predictor:
8 |
9 | def __init__(self, preprocess, postprocess, model, checkpoint_filepath, max_length=30, beam_size=8):
10 | self.preprocess = preprocess
11 | self.postprocess = postprocess
12 | self.model = model
13 | self.max_length = max_length
14 | self.beam_size = beam_size
15 |
16 | self.model.eval()
17 | checkpoint = torch.load(checkpoint_filepath, map_location='cpu')
18 | self.model.load_state_dict(checkpoint)
19 |
20 | def predict_one(self, source, num_candidates=5):
21 | source_preprocessed = self.preprocess(source)
22 |         source_tensor = torch.tensor(source_preprocessed).unsqueeze(0)  # add a batch dimension: (1, source_seq_len)
23 | length_tensor = torch.tensor(len(source_preprocessed)).unsqueeze(0)
24 |
25 | sources_mask = pad_masking(source_tensor, source_tensor.size(1))
26 | memory_mask = pad_masking(source_tensor, 1)
27 | memory = self.model.encoder(source_tensor, sources_mask)
28 |
29 | decoder_state = self.model.decoder.init_decoder_state()
30 | # print('decoder_state src', decoder_state.src.shape)
31 | # print('previous_input previous_input', decoder_state.previous_input)
32 | # print('previous_input previous_layer_inputs ', decoder_state.previous_layer_inputs)
33 |
34 |
35 | # Repeat beam_size times
36 | memory_beam = memory.detach().repeat(self.beam_size, 1, 1) # (beam_size, seq_len, hidden_size)
37 |
38 | beam = Beam(beam_size=self.beam_size, min_length=0, n_top=num_candidates, ranker=None)
39 |
40 | for _ in range(self.max_length):
41 |
42 | new_inputs = beam.get_current_state().unsqueeze(1) # (beam_size, seq_len=1)
43 | decoder_outputs, decoder_state = self.model.decoder(new_inputs, memory_beam,
44 | memory_mask,
45 | state=decoder_state)
46 | # decoder_outputs: (beam_size, target_seq_len=1, vocabulary_size)
47 | # attentions['std']: (target_seq_len=1, beam_size, source_seq_len)
48 |
49 | attention = self.model.decoder.decoder_layers[-1].memory_attention_layer.sublayer.attention
50 | beam.advance(decoder_outputs.squeeze(1), attention)
51 |
52 | beam_current_origin = beam.get_current_origin() # (beam_size, )
53 | decoder_state.beam_update(beam_current_origin)
54 |
55 | if beam.done():
56 | break
57 |
58 | scores, ks = beam.sort_finished(minimum=num_candidates)
59 | hypothesises, attentions = [], []
60 | for i, (times, k) in enumerate(ks[:num_candidates]):
61 | hypothesis, attention = beam.get_hypothesis(times, k)
62 | hypothesises.append(hypothesis)
63 | attentions.append(attention)
64 |
65 | self.attentions = attentions
66 | self.hypothesises = [[token.item() for token in h] for h in hypothesises]
67 | hs = [self.postprocess(h) for h in self.hypothesises]
68 | return list(reversed(hs))
--------------------------------------------------------------------------------
/prepare_datasets.py:
--------------------------------------------------------------------------------
1 | from datasets import TranslationDataset, TranslationDatasetOnTheFly
2 | from datasets import TokenizedTranslationDataset, TokenizedTranslationDatasetOnTheFly
3 | from datasets import InputTargetTranslationDataset, InputTargetTranslationDatasetOnTheFly
4 | from datasets import IndexedInputTargetTranslationDataset, IndexedInputTargetTranslationDatasetOnTheFly
5 | from dictionaries import IndexDictionary
6 | from utils.pipe import shared_tokens_generator, source_tokens_generator, target_tokens_generator
7 |
8 | from argparse import ArgumentParser
9 |
10 | parser = ArgumentParser('Prepare datasets')
11 | parser.add_argument('--train_source', type=str, default='data/example/raw/src-train.txt')
12 | parser.add_argument('--train_target', type=str, default='data/example/raw/tgt-train.txt')
13 | parser.add_argument('--val_source', type=str, default='data/example/raw/src-val.txt')
14 | parser.add_argument('--val_target', type=str, default='data/example/raw/tgt-val.txt')
15 | parser.add_argument('--save_data_dir', type=str, default='data/example/processed')
16 | parser.add_argument('--share_dictionary', action='store_true')  # type=bool would turn any non-empty string into True
17 |
18 | args = parser.parse_args()
19 |
20 | TranslationDataset.prepare(args.train_source, args.train_target, args.val_source, args.val_target, args.save_data_dir)
21 | translation_dataset = TranslationDataset(args.save_data_dir, 'train')
22 | translation_dataset_on_the_fly = TranslationDatasetOnTheFly('train')
23 | assert translation_dataset[0] == translation_dataset_on_the_fly[0]
24 |
25 | tokenized_dataset = TokenizedTranslationDataset(args.save_data_dir, 'train')
26 |
27 | if args.share_dictionary:
28 | source_generator = shared_tokens_generator(tokenized_dataset)
29 | source_dictionary = IndexDictionary(source_generator, mode='source')
30 | target_generator = shared_tokens_generator(tokenized_dataset)
31 | target_dictionary = IndexDictionary(target_generator, mode='target')
32 |
33 | source_dictionary.save(args.save_data_dir)
34 | target_dictionary.save(args.save_data_dir)
35 | else:
36 | source_generator = source_tokens_generator(tokenized_dataset)
37 | source_dictionary = IndexDictionary(source_generator, mode='source')
38 | target_generator = target_tokens_generator(tokenized_dataset)
39 | target_dictionary = IndexDictionary(target_generator, mode='target')
40 |
41 | source_dictionary.save(args.save_data_dir)
42 | target_dictionary.save(args.save_data_dir)
43 |
44 | source_dictionary = IndexDictionary.load(args.save_data_dir, mode='source')
45 | target_dictionary = IndexDictionary.load(args.save_data_dir, mode='target')
46 |
47 | IndexedInputTargetTranslationDataset.prepare(args.save_data_dir, source_dictionary, target_dictionary)
48 | indexed_translation_dataset = IndexedInputTargetTranslationDataset(args.save_data_dir, 'train')
49 | indexed_translation_dataset_on_the_fly = IndexedInputTargetTranslationDatasetOnTheFly('train', source_dictionary, target_dictionary)
50 | assert indexed_translation_dataset[0] == indexed_translation_dataset_on_the_fly[0]  # same sanity check after indexing
51 |
52 | print('Dataset preparation done.')
--------------------------------------------------------------------------------
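A quick way to sanity-check the artifacts this script writes (a minimal sketch; paths follow the script defaults above, and the triple layout of each dataset item matches the unpacking done in utils/pipe.py):

```python
# Minimal sketch: inspect the processed artifacts written by prepare_datasets.py.
# Run the script first so data/example/processed exists.
from datasets import IndexedInputTargetTranslationDataset
from dictionaries import IndexDictionary

source_dictionary = IndexDictionary.load('data/example/processed', mode='source')
target_dictionary = IndexDictionary.load('data/example/processed', mode='target')
print(source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)

dataset = IndexedInputTargetTranslationDataset('data/example/processed', 'train')
indexed_source, indexed_input, indexed_target = dataset[0]  # three lists of token indices
```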
/train.py:
--------------------------------------------------------------------------------
1 | from models import build_model
2 | from datasets import IndexedInputTargetTranslationDataset
3 | from dictionaries import IndexDictionary
4 | from losses import TokenCrossEntropyLoss, LabelSmoothingLoss
5 | from metrics import AccuracyMetric
6 | from optimizers import NoamOptimizer
7 | from trainer import EpochSeq2SeqTrainer
8 | from utils.log import get_logger
9 | from utils.pipe import input_target_collate_fn
10 |
11 | import torch
12 | from torch.optim import Adam
13 | from torch.utils.data import DataLoader
14 | import numpy as np
15 |
16 | from argparse import ArgumentParser
17 | from datetime import datetime
18 | import json
19 | import random
20 |
21 | parser = ArgumentParser(description='Train Transformer')
22 | parser.add_argument('--config', type=str, default=None)
23 |
24 | parser.add_argument('--data_dir', type=str, default='data/example/processed')
25 | parser.add_argument('--save_config', type=str, default=None)
26 | parser.add_argument('--save_checkpoint', type=str, default=None)
27 | parser.add_argument('--save_log', type=str, default=None)
28 |
29 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
30 |
31 | parser.add_argument('--dataset_limit', type=int, default=None)
32 | parser.add_argument('--print_every', type=int, default=1)
33 | parser.add_argument('--save_every', type=int, default=1)
34 |
35 | parser.add_argument('--vocabulary_size', type=int, default=None)
36 | parser.add_argument('--positional_encoding', action='store_true')
37 |
38 | parser.add_argument('--d_model', type=int, default=128)
39 | parser.add_argument('--layers_count', type=int, default=1)
40 | parser.add_argument('--heads_count', type=int, default=2)
41 | parser.add_argument('--d_ff', type=int, default=128)
42 | parser.add_argument('--dropout_prob', type=float, default=0.1)
43 |
44 | parser.add_argument('--label_smoothing', type=float, default=0.1)
45 | parser.add_argument('--optimizer', type=str, default="Adam", choices=["Noam", "Adam"])
46 | parser.add_argument('--lr', type=float, default=0.001)
47 | parser.add_argument('--clip_grads', action='store_true')
48 |
49 | parser.add_argument('--batch_size', type=int, default=64)
50 | parser.add_argument('--epochs', type=int, default=100)
51 |
52 |
53 | def run_trainer(config):
54 | random.seed(0)
55 | np.random.seed(0)
56 | torch.manual_seed(0)
57 |
58 | run_name_format = (
59 | "d_model={d_model}-"
60 | "layers_count={layers_count}-"
61 | "heads_count={heads_count}-"
62 | "pe={positional_encoding}-"
63 | "optimizer={optimizer}-"
64 | "{timestamp}"
65 | )
66 |
67 | run_name = run_name_format.format(**config, timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
68 |
69 | logger = get_logger(run_name, save_log=config['save_log'])
70 | logger.info(f'Run name : {run_name}')
71 | logger.info(config)
72 |
73 | logger.info('Constructing dictionaries...')
74 | source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
75 | target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])
76 | logger.info(f'Source dictionary vocabulary : {source_dictionary.vocabulary_size} tokens')
77 | logger.info(f'Target dictionary vocabulary : {target_dictionary.vocabulary_size} tokens')
78 |
79 | logger.info('Building model...')
80 | model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)
81 |
82 | logger.info(model)
83 | logger.info('Encoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.encoder.parameters()])))
84 | logger.info('Decoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.decoder.parameters()])))
85 | logger.info('Total : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.parameters()])))
86 |
87 | logger.info('Loading datasets...')
88 | train_dataset = IndexedInputTargetTranslationDataset(
89 | data_dir=config['data_dir'],
90 | phase='train',
91 | vocabulary_size=config['vocabulary_size'],
92 | limit=config['dataset_limit'])
93 |
94 | val_dataset = IndexedInputTargetTranslationDataset(
95 | data_dir=config['data_dir'],
96 | phase='val',
97 | vocabulary_size=config['vocabulary_size'],
98 | limit=config['dataset_limit'])
99 |
100 | train_dataloader = DataLoader(
101 | train_dataset,
102 | batch_size=config['batch_size'],
103 | shuffle=True,
104 | collate_fn=input_target_collate_fn)
105 |
106 | val_dataloader = DataLoader(
107 | val_dataset,
108 | batch_size=config['batch_size'],
109 | collate_fn=input_target_collate_fn)
110 |
111 | if config['label_smoothing'] > 0.0:
112 | loss_function = LabelSmoothingLoss(label_smoothing=config['label_smoothing'],
113 | vocabulary_size=target_dictionary.vocabulary_size)
114 | else:
115 | loss_function = TokenCrossEntropyLoss()
116 |
117 | accuracy_function = AccuracyMetric()
118 |
119 | if config['optimizer'] == 'Noam':
120 | optimizer = NoamOptimizer(model.parameters(), d_model=config['d_model'])
121 | elif config['optimizer'] == 'Adam':
122 | optimizer = Adam(model.parameters(), lr=config['lr'])
123 | else:
124 | raise NotImplementedError()
125 |
126 | logger.info('Start training...')
127 | trainer = EpochSeq2SeqTrainer(
128 | model=model,
129 | train_dataloader=train_dataloader,
130 | val_dataloader=val_dataloader,
131 | loss_function=loss_function,
132 | metric_function=accuracy_function,
133 | optimizer=optimizer,
134 | logger=logger,
135 | run_name=run_name,
136 | save_config=config['save_config'],
137 | save_checkpoint=config['save_checkpoint'],
138 | config=config
139 | )
140 |
141 | trainer.run(config['epochs'])
142 |
143 | return trainer
144 |
145 |
146 | if __name__ == '__main__':
147 |
148 | args = parser.parse_args()
149 |
150 | if args.config is not None:
151 | with open(args.config) as f:
152 | config = json.load(f)
153 |
154 | default_config = vars(args)
155 | for key, default_value in default_config.items():
156 | if key not in config:
157 | config[key] = default_value
158 | else:
159 | config = vars(args) # convert to dictionary
160 |
161 | run_trainer(config)
162 |
--------------------------------------------------------------------------------
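A note on the config handling in the `__main__` block above: keys present in the JSON file win, and anything the file omits is filled in from the argparse defaults. A hypothetical partial config therefore only needs the keys it overrides; a sketch:

```python
import json

# Hypothetical partial config: only overridden keys are needed; the merge loop
# in __main__ fills the rest from vars(args).
partial_config = {
    'optimizer': 'Noam',
    'label_smoothing': 0.0,
    'epochs': 10,
}

with open('configs/my_config.json', 'w') as f:  # illustrative path
    json.dump(partial_config, f, indent=2)

# Then launch training with the --config flag pointing at this file:
#   python train.py --config configs/my_config.json
```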
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 | from os.path import dirname, abspath, join, exists
6 | from os import makedirs
7 | from datetime import datetime
8 | import json
9 |
10 | PAD_INDEX = 0
11 |
12 | BASE_DIR = dirname(abspath(__file__))
13 |
14 |
15 | class EpochSeq2SeqTrainer:
16 |
17 | def __init__(self, model,
18 | train_dataloader, val_dataloader,
19 | loss_function, metric_function, optimizer,
20 | logger, run_name,
21 | save_config, save_checkpoint,
22 | config):
23 |
24 | self.config = config
25 | self.device = torch.device(self.config['device'])
26 |
27 | self.model = model.to(self.device)
28 | self.train_dataloader = train_dataloader
29 | self.val_dataloader = val_dataloader
30 |
31 | self.loss_function = loss_function.to(self.device)
32 | self.metric_function = metric_function
33 | self.optimizer = optimizer
34 | self.clip_grads = self.config['clip_grads']
35 |
36 | self.logger = logger
37 | self.checkpoint_dir = join(BASE_DIR, 'checkpoints', run_name)
38 |
39 | if not exists(self.checkpoint_dir):
40 | makedirs(self.checkpoint_dir)
41 |
42 | if save_config is None:
43 | config_filepath = join(self.checkpoint_dir, 'config.json')
44 | else:
45 | config_filepath = save_config
46 | with open(config_filepath, 'w') as config_file:
47 | json.dump(self.config, config_file)
48 |
49 | self.print_every = self.config['print_every']
50 | self.save_every = self.config['save_every']
51 |
52 | self.epoch = 0
53 | self.history = []
54 |
55 | self.start_time = datetime.now()
56 |
57 | self.best_val_metric = None
58 | self.best_checkpoint_filepath = None
59 |
60 | self.save_checkpoint = save_checkpoint
61 | self.save_format = 'epoch={epoch:0>3}-val_loss={val_loss:<.3}-val_metrics={val_metrics}.pth'
62 |
63 | self.log_format = (
64 | "Epoch: {epoch:>3} "
65 | "Progress: {progress:<.1%} "
66 | "Elapsed: {elapsed} "
67 | "Examples/second: {per_second:<.1} "
68 | "Train Loss: {train_loss:<.6} "
69 | "Val Loss: {val_loss:<.6} "
70 | "Train Metrics: {train_metrics} "
71 | "Val Metrics: {val_metrics} "
72 | "Learning rate: {current_lr:<.4} "
73 | )
74 |
75 | def run_epoch(self, dataloader, mode='train'):
76 | batch_losses = []
77 | batch_counts = []
78 | batch_metrics = []
79 | for sources, inputs, targets in tqdm(dataloader):
80 | sources, inputs, targets = sources.to(self.device), inputs.to(self.device), targets.to(self.device)
81 | outputs = self.model(sources, inputs)
82 |
83 | batch_loss, batch_count = self.loss_function(outputs, targets)
84 |
85 | if mode == 'train':
86 | self.optimizer.zero_grad()
87 | batch_loss.backward()
88 | if self.clip_grads:
89 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
90 | self.optimizer.step()
91 |
92 | batch_losses.append(batch_loss.item())
93 | batch_counts.append(batch_count)
94 |
95 | batch_metric, batch_metric_count = self.metric_function(outputs, targets)
96 | batch_metrics.append(batch_metric)
97 |
98 | assert batch_count == batch_metric_count
99 |
100 |         if self.epoch == 0:  # epoch 0: return placeholder statistics instead of computing real ones
101 | return float('inf'), [float('inf')]
102 |
103 | epoch_loss = sum(batch_losses) / sum(batch_counts)
104 | epoch_accuracy = sum(batch_metrics) / sum(batch_counts)
105 | epoch_perplexity = float(np.exp(epoch_loss))
106 | epoch_metrics = [epoch_perplexity, epoch_accuracy]
107 |
108 | return epoch_loss, epoch_metrics
109 |
110 | def run(self, epochs=10):
111 |
112 | for epoch in range(self.epoch, epochs + 1):
113 | self.epoch = epoch
114 |
115 | self.model.train()
116 |
117 | epoch_start_time = datetime.now()
118 | train_epoch_loss, train_epoch_metrics = self.run_epoch(self.train_dataloader, mode='train')
119 | epoch_end_time = datetime.now()
120 |
121 | self.model.eval()
122 |
123 | val_epoch_loss, val_epoch_metrics = self.run_epoch(self.val_dataloader, mode='val')
124 |
125 | if epoch % self.print_every == 0 and self.logger:
126 |                 per_second = len(self.train_dataloader.dataset) / ((epoch_end_time - epoch_start_time).total_seconds() + 1)
127 | current_lr = self.optimizer.param_groups[0]['lr']
128 | log_message = self.log_format.format(epoch=epoch,
129 | progress=epoch / epochs,
130 | per_second=per_second,
131 | train_loss=train_epoch_loss,
132 | val_loss=val_epoch_loss,
133 | train_metrics=[round(metric, 4) for metric in train_epoch_metrics],
134 | val_metrics=[round(metric, 4) for metric in val_epoch_metrics],
135 | current_lr=current_lr,
136 | elapsed=self._elapsed_time()
137 | )
138 |
139 | self.logger.info(log_message)
140 |
141 | if epoch % self.save_every == 0:
142 | self._save_model(epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics)
143 |
144 | def _save_model(self, epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics):
145 |
146 | checkpoint_filename = self.save_format.format(
147 | epoch=epoch,
148 | val_loss=val_epoch_loss,
149 | val_metrics='-'.join(['{:<.3}'.format(v) for v in val_epoch_metrics])
150 | )
151 |
152 | if self.save_checkpoint is None:
153 | checkpoint_filepath = join(self.checkpoint_dir, checkpoint_filename)
154 | else:
155 | checkpoint_filepath = self.save_checkpoint
156 |
157 | save_state = {
158 | 'epoch': epoch,
159 | 'train_loss': train_epoch_loss,
160 | 'train_metrics': train_epoch_metrics,
161 | 'val_loss': val_epoch_loss,
162 | 'val_metrics': val_epoch_metrics,
163 | 'checkpoint': checkpoint_filepath,
164 | }
165 |
166 | if self.epoch > 0:
167 | torch.save(self.model.state_dict(), checkpoint_filepath)
168 | self.history.append(save_state)
169 |
170 | representative_val_metric = val_epoch_metrics[0]
171 | if self.best_val_metric is None or self.best_val_metric > representative_val_metric:
172 | self.best_val_metric = representative_val_metric
173 | self.val_loss_at_best = val_epoch_loss
174 | self.train_loss_at_best = train_epoch_loss
175 | self.train_metrics_at_best = train_epoch_metrics
176 | self.val_metrics_at_best = val_epoch_metrics
177 | self.best_checkpoint_filepath = checkpoint_filepath
178 |
179 | if self.logger:
180 | self.logger.info("Saved model to {}".format(checkpoint_filepath))
181 | self.logger.info("Current best model is {}".format(self.best_checkpoint_filepath))
182 |
183 | def _elapsed_time(self):
184 | now = datetime.now()
185 | elapsed = now - self.start_time
186 | return str(elapsed).split('.')[0] # remove milliseconds
187 |
--------------------------------------------------------------------------------
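A minimal sketch of restoring a checkpoint written by `_save_model` above. It assumes the run directory layout the trainer creates (one directory per `run_name` with a `config.json` next to the checkpoints); the checkpoint filename follows `save_format`, so replace the illustrative path with an actual file:

```python
import json

import torch

from dictionaries import IndexDictionary
from models import build_model

run_dir = 'checkpoints/<run_name>'  # illustrative; one directory per run_name

with open(f'{run_dir}/config.json') as f:  # written by EpochSeq2SeqTrainer.__init__
    config = json.load(f)

source_dictionary = IndexDictionary.load(config['data_dir'], mode='source',
                                         vocabulary_size=config['vocabulary_size'])
target_dictionary = IndexDictionary.load(config['data_dir'], mode='target',
                                         vocabulary_size=config['vocabulary_size'])

model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)

checkpoint_path = f'{run_dir}/epoch=010-val_loss=1.23-val_metrics=....pth'  # illustrative name following save_format
model.load_state_dict(torch.load(checkpoint_path, map_location=config['device']))
model.eval()
```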
/utils/log.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, abspath, join, exists
2 | import os, sys
3 | import logging
4 |
5 | BASE_DIR = dirname(dirname(abspath(__file__)))
6 |
7 |
8 | def get_logger(run_name, save_log=None):
9 | log_dir = join(BASE_DIR, 'logs')
10 | if not exists(log_dir):
11 | os.makedirs(log_dir)
12 |
13 | log_filename = f'{run_name}.log'
14 | if save_log is None:
15 | log_filepath = join(log_dir, log_filename)
16 | else:
17 | log_filepath = save_log
18 |
19 | logger = logging.getLogger(run_name)
20 |
21 |     if not logger.handlers:  # attach handlers only once; getLogger reuses the logger for a given name
22 | file_handler = logging.FileHandler(log_filepath, 'w', 'utf-8')
23 |         stream_handler = logging.StreamHandler(sys.stdout)
24 |
25 | formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
26 |
27 | file_handler.setFormatter(formatter)
28 | stream_handler.setFormatter(formatter)
29 |
30 | logger.addHandler(file_handler)
31 | logger.addHandler(stream_handler)
32 | logger.setLevel(logging.INFO)
33 |
34 | return logger
--------------------------------------------------------------------------------
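Usage sketch for `get_logger`: `logging.getLogger` returns the same logger object for a given name, so the handlers guard above keeps repeated calls from attaching duplicate handlers:

```python
from utils.log import get_logger

logger = get_logger('demo-run')   # writes to logs/demo-run.log and stdout
logger.info('hello')              # appears in both sinks

same_logger = get_logger('demo-run')
assert same_logger is logger      # same object, handlers attached only once
```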
/utils/pad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | PAD_TOKEN_INDEX = 0
5 |
6 |
7 | def pad_masking(x, target_len):
8 | # x: (batch_size, seq_len)
9 | batch_size, seq_len = x.size()
10 | padded_positions = x == PAD_TOKEN_INDEX # (batch_size, seq_len)
11 | pad_mask = padded_positions.unsqueeze(1).expand(batch_size, target_len, seq_len)
12 | return pad_mask
13 |
14 |
15 | def subsequent_masking(x):
16 |     # x: (batch_size, seq_len); x is the decoder input, one token shorter than the full target
17 | batch_size, seq_len = x.size()
18 | subsequent_mask = np.triu(np.ones(shape=(seq_len, seq_len)), k=1).astype('uint8')
19 | subsequent_mask = torch.tensor(subsequent_mask).to(x.device)
20 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(batch_size, seq_len, seq_len)
21 | return subsequent_mask
--------------------------------------------------------------------------------
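A worked example of the two masks on a toy batch; nonzero entries mark the key positions that downstream attention is expected to block (padding columns and future positions, respectively):

```python
import torch

from utils.pad import pad_masking, subsequent_masking

# Toy batch: PAD_TOKEN_INDEX is 0, so trailing zeros are padding.
x = torch.tensor([[5, 7, 0, 0],
                  [3, 2, 6, 0]])               # (batch_size=2, seq_len=4)

pad_mask = pad_masking(x, target_len=4)        # shape (2, 4, 4)
# pad_mask[0] flags columns 2 and 3 for every query row: keys that are
# padding are masked regardless of the query position.

sub_mask = subsequent_masking(x)               # shape (2, 4, 4)
# sub_mask[b][i][j] == 1 exactly when j > i: each position may attend only
# to itself and earlier positions.
print(pad_mask.shape, sub_mask.shape)
```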
/utils/pipe.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | PAD_INDEX = 0
4 |
5 |
6 | def input_target_collate_fn(batch):
7 | """merges a list of samples to form a mini-batch."""
8 |
9 | # indexed_sources = [sources for sources, inputs, targets in batch]
10 | # indexed_inputs = [inputs for sources, inputs, targets in batch]
11 | # indexed_targets = [targets for sources, inputs, targets in batch]
12 |
13 | sources_lengths = [len(sources) for sources, inputs, targets in batch]
14 | inputs_lengths = [len(inputs) for sources, inputs, targets in batch]
15 | targets_lengths = [len(targets) for sources, inputs, targets in batch]
16 |
17 | sources_max_length = max(sources_lengths)
18 | inputs_max_length = max(inputs_lengths)
19 | targets_max_length = max(targets_lengths)
20 |
21 | sources_padded = [sources + [PAD_INDEX] * (sources_max_length - len(sources)) for sources, inputs, targets in batch]
22 | inputs_padded = [inputs + [PAD_INDEX] * (inputs_max_length - len(inputs)) for sources, inputs, targets in batch]
23 | targets_padded = [targets + [PAD_INDEX] * (targets_max_length - len(targets)) for sources, inputs, targets in batch]
24 |
25 | sources_tensor = torch.tensor(sources_padded)
26 | inputs_tensor = torch.tensor(inputs_padded)
27 | targets_tensor = torch.tensor(targets_padded)
28 |
29 | # lengths = {
30 | # 'sources_lengths': torch.tensor(sources_lengths),
31 | # 'inputs_lengths': torch.tensor(inputs_lengths),
32 | # 'targets_lengths': torch.tensor(targets_lengths)
33 | # }
34 |
35 | return sources_tensor, inputs_tensor, targets_tensor
36 |
37 |
38 | def shared_tokens_generator(dataset):
39 | for source, target in dataset:
40 | for token in source:
41 | yield token
42 | for token in target:
43 | yield token
44 |
45 |
46 | def source_tokens_generator(dataset):
47 | for source, target in dataset:
48 | for token in source:
49 | yield token
50 |
51 |
52 | def target_tokens_generator(dataset):
53 | for source, target in dataset:
54 | for token in target:
55 | yield token
--------------------------------------------------------------------------------
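A toy batch (with made-up token indices) through `input_target_collate_fn`, showing each field padded to its own batch maximum:

```python
from utils.pipe import input_target_collate_fn

batch = [
    ([4, 5, 6], [2, 7],     [7, 3]),
    ([8],       [2, 9, 10], [9, 10, 3]),
]
sources, inputs, targets = input_target_collate_fn(batch)
print(sources)   # tensor([[4, 5, 6],
                 #         [8, 0, 0]])
print(inputs)    # tensor([[ 2,  7,  0],
                 #         [ 2,  9, 10]])
print(targets)   # tensor([[ 7,  3,  0],
                 #         [ 9, 10,  3]])
```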