├── .gitignore ├── LICENSE ├── README.md ├── environment.yaml ├── examples ├── bert_pretraining.py ├── gpt_pretraining.py ├── overfitting_test.py └── speed_check.py ├── run-tests.sh ├── setup.py └── src ├── main └── python │ └── transformer │ ├── __init__.py │ ├── bert │ ├── __init__.py │ └── mlm_loss.py │ ├── decoder.py │ ├── enc_dec_base.py │ ├── encoder.py │ ├── feed_forward_layer.py │ ├── multi_head_attention.py │ ├── normalization.py │ ├── transformer.py │ ├── transformer_tools.py │ └── util.py └── test └── python └── transformer_test ├── __init__.py ├── decoder_test.py ├── encoder_test.py ├── feed_forward_layer_test.py ├── multi_head_attention_test.py ├── normalization_test.py ├── transformer_tools_test.py └── util_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 2-Clause BSD License 2 | # 3 | # Copyright (c) 2018, Patrick Hohenecker 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # 1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 2. Redistributions in binary form must reproduce the above copyright notice, 12 | # this list of conditions and the following disclaimer in the documentation 13 | # and/or other materials provided with the distribution. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 19 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | 26 | # author: Patrick Hohenecker 27 | # version: 2018.1 28 | # date: Aug 21, 2018 29 | 30 | 31 | ### Python 32 | __pycache__/** 33 | *.py[cdo] 34 | *$py.class 35 | 36 | ### Distribution / Packaging 37 | .Python 38 | build/ 39 | develop-eggs/ 40 | dist/ 41 | downloads/ 42 | eggs/ 43 | .eggs/ 44 | lib/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | wheels/ 50 | *.egg-info/ 51 | .installed.cfg 52 | *.egg 53 | MANIFEST 54 | 55 | ### Mac 56 | .DS_Store 57 | ._* 58 | 59 | ### IntelliJ IDEA / PyCharm 60 | .idea 61 | *.iml 62 | *.iws 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2-Clause BSD License 2 | 3 | Copyright (c) 2018, Patrick Hohenecker 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 
14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | pytorch-transformer 2 | =================== 3 | 4 | 5 | This repository provides a PyTorch implementation of the *Transformer* model that has been introduced in the paper 6 | *Attention Is All You Need* (Vaswani et al. 2017). 7 | 8 | 9 | Installation 10 | ------------ 11 | 12 | The easiest way to install this package is via pip: 13 | 14 | ```bash 15 | pip install git+https://github.com/phohenecker/pytorch-transformer 16 | ``` 17 | 18 | 19 | Usage 20 | ----- 21 | 22 | ```python 23 | import transformer 24 | model = transformer.Transformer(...) 25 | ``` 26 | 27 | ##### 1. Computing Predictions given a Target Sequence 28 | 29 | This is the default behaviour of a 30 | [`Transformer`](src/main/python/transformer/transformer.py), 31 | and is implemented in its 32 | [`forward`](src/main/python/transformer/transformer.py#L205) 33 | method: 34 | ```python 35 | predictions = model(input_seq, target_seq) 36 | ``` 37 | 38 | 39 | ##### 2. 
Evaluating the Probability of a Target Sequence 40 | 41 | The probability of an output sequence given an input sequence under an already trained model can be evaluated by means 42 | of the function 43 | [`eval_probability`](src/main/python/transformer/transformer_tools.py#L46): 44 | ```python 45 | probabilities = transformer.eval_probability(model, input_seq, target_seq, pad_index=...) 46 | ``` 47 | 48 | ##### 3. Sampling an Output Sequence 49 | 50 | Sampling a random output given an input sequence under the distribution computed by a model is realized by the function 51 | [`sample_output`](src/main/python/transformer/transformer_tools.py#L115): 52 | 53 | ```python 54 | output_seq = transformer.sample_output(model, input_seq, eos_index, pad_index, max_len) 55 | ``` 56 | 57 | 58 | Pretraining Encoders with BERT 59 | ------------------------------ 60 | 61 | For pretraining the encoder part of the transformer 62 | (i.e.,[`transformer.Encoder`](src/main/python/transformer/encoder.py)) 63 | with BERT (Devlin et al., 2018), the class [`MLMLoss`](src/main/python/transformer/bert/mlm_loss.py) provides an 64 | implementation of the masked language-model loss function. 65 | A full example of how to implement pretraining with BERT can be found in 66 | [`examples/bert_pretraining.py`](examples/bert_pretraining.py). 67 | 68 | 69 | References 70 | ---------- 71 | 72 | > Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., Polosukhin, I. (2017). 73 | > Attention Is All You Need. 74 | > Preprint at http://arxiv.org/abs/1706.03762. 75 | 76 | > Devlin, J., Chang, M.-W., Lee, K., & Toutanova, K. (2018). 77 | > BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. 78 | > Preprint at http://arxiv.org/abs/1810.04805. 
79 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | # 2-Clause BSD License 2 | # 3 | # Copyright (c) 2018, Patrick Hohenecker 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # 1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 2. Redistributions in binary form must reproduce the above copyright notice, 12 | # this list of conditions and the following disclaimer in the documentation 13 | # and/or other materials provided with the distribution. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 19 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | 26 | # author: Patrick Hohenecker 27 | # version: 2018.1 28 | # date: Aug 21, 2018 29 | 30 | 31 | name: pytorch-transformer 32 | channels: 33 | - pytorch 34 | dependencies: 35 | - python>=3.6 36 | - cython 37 | - numpy>=1.15.0 38 | - pytorch>=0.4.1 39 | - pip: 40 | - gensim>=3.7.2 41 | - insanity>=2017.1 42 | - torchtestcase>=2018.2 43 | -------------------------------------------------------------------------------- /examples/bert_pretraining.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """An example of how to pretrain a transformer encoder with BERT.""" 5 | 6 | 7 | import collections 8 | import itertools 9 | import typing 10 | 11 | import gensim.models.word2vec as word2vec 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | 17 | import transformer 18 | import transformer.bert as bert 19 | 20 | 21 | __author__ = "Patrick Hohenecker" 22 | __copyright__ = ( 23 | "Copyright (c) 2019, Patrick Hohenecker\n" 24 | "All rights reserved.\n" 25 | "\n" 26 | "Redistribution and use in source and binary forms, with or without\n" 27 | "modification, are permitted provided that the following conditions are met:\n" 28 | "\n" 29 | "1. Redistributions of source code must retain the above copyright notice, this\n" 30 | " list of conditions and the following disclaimer.\n" 31 | "2. Redistributions in binary form must reproduce the above copyright notice,\n" 32 | " this list of conditions and the following disclaimer in the documentation\n" 33 | " and/or other materials provided with the distribution.\n" 34 | "\n" 35 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 36 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 37 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 38 | "DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 39 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 40 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 41 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 42 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 43 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 44 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 45 | ) 46 | __license__ = "BSD-2-Clause" 47 | __version__ = "2019.1" 48 | __date__ = "23 Apr 2019" 49 | __maintainer__ = "Patrick Hohenecker" 50 | __email__ = "mail@paho.at" 51 | __status__ = "Development" 52 | 53 | 54 | # ==================================================================================================================== # 55 | # C O N S T A N T S # 56 | # ==================================================================================================================== # 57 | 58 | 59 | Token = collections.namedtuple("Token", ["index", "word"]) 60 | """This is used to store index-word pairs.""" 61 | 62 | DATA = [ 63 | "where the streets have no name", 64 | "we ' re still building then burning down love", 65 | "burning down love", 66 | "and when i go there , i go there with you", 67 | "it ' s all i can do" 68 | ] 69 | """list[str]: The already preprocessed training data.""" 70 | 71 | 72 | # SPECIAL TOKENS ##################################################################################################### 73 | 74 | SOS = Token(0, "") 75 | """The start-of-sequence token.""" 76 | 77 | EOS = Token(1, "") 78 | """The end-of-sequence token.""" 79 | 80 | PAD = Token(2, "") 81 | """The padding token.""" 82 | 83 | MASK = Token(3, "") 84 | """The mask token.""" 85 | 86 | 87 | # MODEL CONFIG ####################################################################################################### 88 | 
89 | DIMENSIONS = (256, 32, 32) 90 | """tuple[int]: A tuple of d_model, d_k, d_v.""" 91 | 92 | DROPOUT_RATE = 0.1 93 | """float: The used dropout rate.""" 94 | 95 | EMBEDDING_SIZE = DIMENSIONS[0] 96 | """int: The used embedding size.""" 97 | 98 | NUM_LAYERS = 6 99 | """int: The number of layers in the trained transformer encoder.""" 100 | 101 | 102 | # TRAINING DETAILS ################################################################################################### 103 | 104 | GPU = False # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< SET THIS TO True, IF YOU ARE USING A MACHINE WITH A GPU! 105 | """bool: Indicates whether to make use of a GPU.""" 106 | 107 | LEARNING_RATE = 0.0001 108 | """float: The used learning rate.""" 109 | 110 | NUM_EPOCHS = 500 111 | """int: The total number of training epochs.""" 112 | 113 | NUM_HEADS = 6 114 | """int: The number of attention heads to use.""" 115 | 116 | 117 | # ==================================================================================================================== # 118 | # H E L P E R F U N C T I O N S # 119 | # ==================================================================================================================== # 120 | 121 | 122 | def prepare_data() -> typing.Tuple[typing.List[typing.List[str]], collections.OrderedDict]: 123 | """Preprocesses the training data, and creates the vocabulary. 124 | 125 | Returns: 126 | list[list[str]]: The training data as list of samples, each of which is a list of words. 127 | collections.OrderedDict: The vocabulary as an ``OrderedDict`` from words to indices. 
128 | """ 129 | 130 | # gather all words that appear in the data 131 | all_words = set() 132 | for sample in DATA: 133 | all_words.update(sample.split(" ")) 134 | 135 | # create the vocabulary 136 | vocab = collections.OrderedDict( 137 | [ 138 | (SOS.word, SOS.index), 139 | (EOS.word, EOS.index), 140 | (PAD.word, PAD.index), 141 | (MASK.word, MASK.index) 142 | ] 143 | ) 144 | for idx, word in enumerate(sorted(all_words)): 145 | vocab[word] = idx + 4 146 | 147 | # split, add ..., and pad the dataset 148 | data = [[SOS.word] + sample.split(" ") + [EOS.word] for sample in DATA] 149 | max_len = max(len(sample) for sample in data) 150 | data = [sample + ([PAD.word] * (max_len - len(sample))) for sample in data] 151 | 152 | return data, vocab 153 | 154 | 155 | # ==================================================================================================================== # 156 | # M A I N # 157 | # ==================================================================================================================== # 158 | 159 | 160 | def main(): 161 | 162 | # fetch the training data 163 | data, vocab = prepare_data() 164 | 165 | # create the word embeddings with word2vec and positional embeddings 166 | emb_model = word2vec.Word2Vec( 167 | sentences=data, 168 | size=EMBEDDING_SIZE, 169 | min_count=1 170 | ) 171 | for word in vocab.keys(): 172 | if word not in emb_model.wv: 173 | emb_model.wv[word] = np.zeros((EMBEDDING_SIZE,)) 174 | word_emb_mat = nn.Parameter( 175 | data=torch.FloatTensor([emb_model[word] for word in vocab.keys()]), 176 | requires_grad=False 177 | ) 178 | word_emb = nn.Embedding(len(vocab), EMBEDDING_SIZE) 179 | word_emb.weight = word_emb_mat 180 | pos_emb = nn.Embedding(len(data[0]), EMBEDDING_SIZE) 181 | pos_emb.weight.require_grad = True 182 | 183 | # turn the dataset into a tensor of word indices 184 | data = torch.LongTensor([[vocab[word] for word in sample] for sample in data]) 185 | 186 | # create the encoder, the pretraining loss, and the 
optimizer 187 | encoder = transformer.Encoder( 188 | NUM_LAYERS, # num_layers 189 | NUM_HEADS, # num_heads 190 | *DIMENSIONS, # dim_model / dim_keys / dim_values 191 | DROPOUT_RATE, # residual_dropout 192 | DROPOUT_RATE, # attention_dropout 193 | PAD.index # pad_index 194 | ) 195 | loss = bert.MLMLoss( 196 | encoder, 197 | word_emb, 198 | pos_emb, 199 | MASK.index 200 | ) 201 | optimizer = optim.Adam( 202 | itertools.chain(encoder.parameters(), loss.parameters()), 203 | lr=LEARNING_RATE 204 | ) 205 | 206 | # move to GPU, if possible 207 | if GPU: 208 | data = data.cuda() 209 | encoder.cuda() 210 | loss.cuda() # -> also moves embeddings to the GPU 211 | 212 | # pretrain the encoder 213 | for epoch in range(NUM_EPOCHS): 214 | 215 | # compute the loss 216 | optimizer.zero_grad() 217 | current_loss = loss(data) 218 | print("EPOCH", epoch + 1, ": LOSS =", current_loss.item()) 219 | 220 | # update the model 221 | current_loss.backward() 222 | optimizer.step() 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | -------------------------------------------------------------------------------- /examples/gpt_pretraining.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """An example of how to pretrain a transformer encoder GPT-style.""" 5 | 6 | 7 | import collections 8 | import itertools 9 | import typing 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | 15 | import transformer 16 | import transformer.util as util 17 | 18 | 19 | __author__ = "Patrick Hohenecker" 20 | __copyright__ = ( 21 | "Copyright (c) 2019, Patrick Hohenecker\n" 22 | "All rights reserved.\n" 23 | "\n" 24 | "Redistribution and use in source and binary forms, with or without\n" 25 | "modification, are permitted provided that the following conditions are met:\n" 26 | "\n" 27 | "1. 
Redistributions of source code must retain the above copyright notice, this\n" 28 | " list of conditions and the following disclaimer.\n" 29 | "2. Redistributions in binary form must reproduce the above copyright notice,\n" 30 | " this list of conditions and the following disclaimer in the documentation\n" 31 | " and/or other materials provided with the distribution.\n" 32 | "\n" 33 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 34 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 35 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 36 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 37 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 38 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 39 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 40 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 41 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 42 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 
43 | ) 44 | __license__ = "BSD-2-Clause" 45 | __version__ = "2019.1" 46 | __date__ = "13 Jul 2019" 47 | __maintainer__ = "Patrick Hohenecker" 48 | __email__ = "mail@paho.at" 49 | __status__ = "Development" 50 | 51 | 52 | # ==================================================================================================================== # 53 | # C O N S T A N T S # 54 | # ==================================================================================================================== # 55 | 56 | 57 | Token = collections.namedtuple("Token", ["index", "word"]) 58 | """This is used to store index-word pairs.""" 59 | 60 | DATA = [ 61 | "where the streets have no name", 62 | "we ' re still building then burning down love", 63 | "burning down love", 64 | "and when i go there , i go there with you", 65 | "it ' s all i can do" 66 | ] 67 | """list[str]: The already preprocessed training data.""" 68 | 69 | # SPECIAL TOKENS ##################################################################################################### 70 | 71 | SOS = Token(0, "") 72 | """The start-of-sequence token.""" 73 | 74 | EOS = Token(1, "") 75 | """The end-of-sequence token.""" 76 | 77 | PAD = Token(2, "") 78 | """The padding token.""" 79 | 80 | MASK = Token(3, "") 81 | """The mask token.""" 82 | 83 | # MODEL CONFIG ####################################################################################################### 84 | 85 | DIMENSIONS = (256, 32, 32) 86 | """tuple[int]: A tuple of d_model, d_k, d_v.""" 87 | 88 | DROPOUT_RATE = 0 89 | """float: The used dropout rate.""" 90 | 91 | EMBEDDING_SIZE = DIMENSIONS[0] 92 | """int: The used embedding size.""" 93 | 94 | NUM_LAYERS = 3 95 | """int: The number of layers in the trained transformer encoder.""" 96 | 97 | # TRAINING DETAILS ################################################################################################### 98 | 99 | GPU = False # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< SET THIS TO True, IF YOU ARE USING A MACHINE 
WITH A GPU! 100 | """bool: Indicates whether to make use of a GPU.""" 101 | 102 | LEARNING_RATE = 0.0001 103 | """float: The used learning rate.""" 104 | 105 | NUM_EPOCHS = 500 106 | """int: The total number of training epochs.""" 107 | 108 | NUM_HEADS = 6 109 | """int: The number of attention heads to use.""" 110 | 111 | 112 | # ==================================================================================================================== # 113 | # H E L P E R F U N C T I O N S # 114 | # ==================================================================================================================== # 115 | 116 | 117 | def prepare_data() -> typing.Tuple[typing.List[typing.List[str]], collections.OrderedDict]: 118 | """Preprocesses the training data, and creates the vocabulary. 119 | 120 | Returns: 121 | list[list[str]]: The training data as list of samples, each of which is a list of words. 122 | collections.OrderedDict: The vocabulary as an ``OrderedDict`` from words to indices. 
123 | """ 124 | 125 | # gather all words that appear in the data 126 | all_words = set() 127 | for sample in DATA: 128 | all_words.update(sample.split(" ")) 129 | 130 | # create the vocabulary 131 | vocab = collections.OrderedDict( 132 | [ 133 | (SOS.word, SOS.index), 134 | (EOS.word, EOS.index), 135 | (PAD.word, PAD.index), 136 | (MASK.word, MASK.index) 137 | ] 138 | ) 139 | for idx, word in enumerate(sorted(all_words)): 140 | vocab[word] = idx + 4 141 | 142 | # split, add ..., and pad the dataset 143 | data = [[SOS.word] + sample.split(" ") + [EOS.word] for sample in DATA] 144 | max_len = max(len(sample) for sample in data) 145 | data = [sample + ([PAD.word] * (max_len - len(sample))) for sample in data] 146 | 147 | return data, vocab 148 | 149 | 150 | # ==================================================================================================================== # 151 | # M A I N # 152 | # ==================================================================================================================== # 153 | 154 | 155 | def main(): 156 | 157 | # fetch the training data 158 | data, vocab = prepare_data() 159 | 160 | # create the word embeddings and positional embeddings (we learn both of them) 161 | word_emb = nn.Embedding(len(vocab), EMBEDDING_SIZE) 162 | pos_emb = nn.Embedding(len(data[0]), EMBEDDING_SIZE) 163 | 164 | # turn the dataset into a tensor of word indices 165 | data = torch.LongTensor([[vocab[word] for word in sample] for sample in data]) 166 | 167 | # create the encoder, the pretraining loss, and the optimizer 168 | encoder = transformer.Encoder( 169 | NUM_LAYERS, # num_layers 170 | NUM_HEADS, # num_heads 171 | *DIMENSIONS, # dim_model / dim_keys / dim_values 172 | DROPOUT_RATE, # residual_dropout 173 | DROPOUT_RATE, # attention_dropout 174 | PAD.index # pad_index 175 | ) 176 | loss = nn.CrossEntropyLoss() 177 | optimizer = optim.Adam( 178 | itertools.chain(encoder.parameters(), word_emb.parameters(), pos_emb.parameters()), 179 | 
lr=LEARNING_RATE 180 | ) 181 | 182 | # move to GPU, if possible 183 | if GPU: 184 | data = data.cuda() 185 | encoder.cuda() 186 | word_emb.cuda() 187 | pos_emb.cuda() 188 | 189 | # create a mask that ensures that no future steps can be used 190 | mask = util.create_shifted_output_mask(data)[:, :-1, :-1] # -> cut off final time step, which is never an input 191 | 192 | # create a tensor of indices, which is used to retrieve the according positional embeddings below 193 | index_seq = data.new(range(data.size(1) - 1)).unsqueeze(0).expand(data.size(0), -1) 194 | 195 | # pretrain the encoder 196 | for epoch in range(NUM_EPOCHS): 197 | 198 | # embed input sequence + add positional embeddings 199 | input_seq = word_emb(data[:, :-1]) + pos_emb(index_seq) 200 | 201 | # encode the input sequence 202 | enc = encoder(input_seq, mask) 203 | 204 | # compute (unnormalized) next-word predictions from the encoded input sequences 205 | logits = enc.matmul(word_emb.weight.transpose(0, 1)) 206 | 207 | # compute the loss 208 | optimizer.zero_grad() 209 | current_loss = loss(logits.view(-1, logits.size(-1)), data[:, 1:].contiguous().view(-1)) 210 | print(f"EPOCH {epoch + 1:>3}: LOSS = {current_loss.item()}") 211 | 212 | # update the model 213 | current_loss.backward() 214 | optimizer.step() 215 | 216 | # evaluate the probabilities of the training samples 217 | encoder.eval() 218 | input_seq = word_emb(data[:, :-1]) + pos_emb(index_seq) 219 | enc = encoder(input_seq, mask) 220 | log_probs = torch.log_softmax(enc.matmul(word_emb.weight.transpose(0, 1)), 2) 221 | sample_probs = [] 222 | for sample_idx, sample_log_probs in enumerate(log_probs): 223 | sample_data = data[sample_idx][1:].unsqueeze(1) 224 | sample_log_probs = sample_log_probs.gather(1, sample_data) * (sample_data != PAD.index).float() 225 | sample_probs.append(sample_log_probs.sum().exp().item()) 226 | print("\nSAMPLE PROBABILITIES:") 227 | for p in sample_probs: 228 | print("*", p) 229 | 230 | 231 | if __name__ == "__main__": 
232 | main() 233 | -------------------------------------------------------------------------------- /examples/overfitting_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """An implementation of the overfitting test for the Transformer model. 5 | 6 | A simple test, which often signifies bugs in the implementation of a model, is the overfitting test. To that end, the 7 | considered model is trained and evaluated on the same tiny dataset, which it should be able to overfit easily. 8 | Therefore, the final model should yield very high probabilities for the desired target values. If this is not the case, 9 | however, then there is probably something wrong with the tested model and/or its implementation. 10 | 11 | In this module, we test our implementation of the Transformer model on a super-simple translation task from German to 12 | English. To that end, the considered corpus consists of 5 short and already pre-processed sentences, and is specified in 13 | this file (see below). 14 | """ 15 | 16 | 17 | import collections 18 | import itertools 19 | import typing 20 | 21 | import torch 22 | 23 | import transformer 24 | 25 | from torch import nn 26 | from torch import optim 27 | 28 | 29 | __author__ = "Patrick Hohenecker" 30 | __copyright__ = ( 31 | "Copyright (c) 2018, Patrick Hohenecker\n" 32 | "All rights reserved.\n" 33 | "\n" 34 | "Redistribution and use in source and binary forms, with or without\n" 35 | "modification, are permitted provided that the following conditions are met:\n" 36 | "\n" 37 | "1. Redistributions of source code must retain the above copyright notice, this\n" 38 | " list of conditions and the following disclaimer.\n" 39 | "2. 
Redistributions in binary form must reproduce the above copyright notice,\n" 40 | " this list of conditions and the following disclaimer in the documentation\n" 41 | " and/or other materials provided with the distribution.\n" 42 | "\n" 43 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 44 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 45 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 46 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 47 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 48 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 49 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 50 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 51 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 52 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 
53 | ) 54 | __license__ = "BSD-2-Clause" 55 | __version__ = "2018.1" 56 | __date__ = "Aug 29, 2018" 57 | __maintainer__ = "Patrick Hohenecker" 58 | __email__ = "mail@paho.at" 59 | __status__ = "Development" 60 | 61 | 62 | Token = collections.namedtuple("Token", ["index", "word"]) 63 | """This is used to store index-word pairs.""" 64 | 65 | 66 | # ==================================================================================================================== # 67 | # C O N S T A N T S # 68 | # ==================================================================================================================== # 69 | 70 | 71 | # PARALLEL DATA ###################################################################################################### 72 | 73 | DATA_GERMAN = [ 74 | "Alle warten auf das Licht .", 75 | "Fürchtet euch , fürchtet euch nicht .", 76 | "Die Sonne scheint mir aus den Augen .", 77 | "Sie wird heut ' Nacht nicht untergehen .", 78 | "Und die Welt zählt laut bis 10 ." 79 | ] 80 | 81 | DATA_ENGLISH = [ 82 | "Everyone is waiting for the light .", 83 | "Be afraid , do not be afraid .", 84 | "The sun is shining out of my eyes .", 85 | "It will not go down tonight .", 86 | "And the world counts up to 10 loudly ." 87 | ] 88 | 89 | 90 | # SPECIAL TOKENS ##################################################################################################### 91 | 92 | SOS = Token(0, "") 93 | """str: The start-of-sequence token.""" 94 | 95 | EOS = Token(1, "") 96 | """str: The end-of-sequence token.""" 97 | 98 | PAD = Token(2, "") 99 | """str: The padding token.""" 100 | 101 | 102 | # MODEL CONFIG ####################################################################################################### 103 | 104 | EMBEDDING_SIZE = 300 105 | """int: The used embedding size.""" 106 | 107 | GPU = False # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< SET THIS TO True, IF YOU ARE USING A MACHINE WITH A GPU! 
"""bool: Indicates whether to make use of a GPU."""

NUM_EPOCHS = 200
"""int: The total number of training epochs."""


# ==================================================================================================================== #
# H E L P E R F U N C T I O N S #
# ==================================================================================================================== #


def eval_model(model: transformer.Transformer, input_seq: torch.LongTensor, target_seq: torch.LongTensor) -> None:
    """Evaluates the provided model on the given data, and prints the probabilities of the desired translations.

    Args:
        model (:class:`transformer.Transformer`): The model to evaluate.
        input_seq (torch.LongTensor): The input sequences, as (batch-size x max-input-seq-len) tensor.
        target_seq (torch.LongTensor): The target sequences, as (batch-size x max-target-seq-len) tensor.
    """
    probs = transformer.eval_probability(model, input_seq, target_seq, pad_index=PAD.index)

    # the result lives on the same device as the model, and ``numpy()`` refuses CUDA tensors -> copy to host first
    probs = probs.detach().cpu().numpy().tolist()

    print("sample " + ("{} " * len(probs)).format(*range(len(probs))))
    print("probability " + ("{:.6f} " * len(probs)).format(*probs))


def fetch_vocab() -> typing.Tuple[typing.List[str], typing.Dict[str, int]]:
    """Determines the vocabulary, and provides mappings from indices to words and vice versa.

    The vocabulary consists of the special tokens (start-of-sequence, end-of-sequence, padding), which occupy the first
    three indices, followed by all lower-cased words of both corpora in sorted order.

    Returns:
        tuple: A pair of mappings, index-to-word (a list) and word-to-index (a dictionary).
    """
    # gather all (lower-cased) words that appear in the data
    all_words = set()
    for sentence in itertools.chain(DATA_GERMAN, DATA_ENGLISH):
        all_words.update(word.lower() for word in sentence.split(" "))

    # create mapping from index to word (special tokens first, so that their indices match the Token definitions)
    idx_to_word = [SOS.word, EOS.word, PAD.word] + sorted(all_words)

    # create mapping from word to index
    word_to_idx = {word: idx for idx, word in enumerate(idx_to_word)}

    return idx_to_word, word_to_idx


def _encode_corpus(sentences: typing.Iterable[str], word_to_idx: typing.Dict[str, int]) -> torch.LongTensor:
    """Tokenizes, delimits, pads, and indexes the given sentences as a (num-sentences x max-len) ``LongTensor``."""
    # break sentences into word tokens, and wrap each of them in start-/end-of-sequence tokens
    tokenized = [[SOS.word] + sentence.split(" ") + [EOS.word] for sentence in sentences]

    # pad all sentences to equal length (an empty corpus simply yields an empty tensor)
    max_len = max((len(tokens) for tokens in tokenized), default=0)
    padded = [tokens + [PAD.word] * (max_len - len(tokens)) for tokens in tokenized]

    # map words to indices in the vocabulary -> the vocabulary is lower-cased, hence so is the lookup
    return torch.LongTensor([[word_to_idx[word.lower()] for word in tokens] for tokens in padded])


def prepare_data(word_to_idx: typing.Dict[str, int]) -> typing.Tuple[torch.LongTensor, torch.LongTensor]:
    """Prepares the data as PyTorch ``LongTensor``s.

    Args:
        word_to_idx (dict[str, int]): A dictionary that maps words to indices in the vocabulary.

    Returns:
        tuple: A pair of ``LongTensor``s, the first representing the input and the second the target sequence.

    Raises:
        KeyError: If a (lower-cased) word of the corpora is missing from ``word_to_idx``.
    """
    # each language is padded independently, as input and target sequences need not be of equal length
    german = _encode_corpus(DATA_GERMAN, word_to_idx)
    english = _encode_corpus(DATA_ENGLISH, word_to_idx)

    return german, english
# ==================================================================================================================== #
#  M A I N                                                                                                             #
# ==================================================================================================================== #


def main():
    """Overfits a transformer on the toy parallel corpus, and reports how well the translations are reproduced.

    The probabilities of the target translations are printed before and after training, followed by outputs that
    are sampled from the trained model for each of the input sequences.
    """
    # fetch vocabulary + prepare data
    idx_to_word, word_to_idx = fetch_vocab()
    input_seq, target_seq = prepare_data(word_to_idx)

    # create embeddings to use
    emb = nn.Embedding(len(idx_to_word), EMBEDDING_SIZE)
    emb.reset_parameters()

    # create transformer model
    model = transformer.Transformer(
        emb,
        PAD.index,
        emb.num_embeddings,
        max_seq_len=max(input_seq.size(1), target_seq.size(1))
    )

    print("Initial Probabilities of Translations:")
    print("--------------------------------------")
    eval_model(model, input_seq, target_seq)
    print()

    # move model + data on the GPU (if possible)
    if GPU:
        model.cuda()
        input_seq = input_seq.cuda()
        target_seq = target_seq.cuda()

    # create an optimizer for training the model + a X-entropy loss
    # -> the optimizer is created only AFTER the model has been moved to its final device, as recommended by PyTorch
    optimizer = optim.Adam((param for param in model.parameters() if param.requires_grad), lr=0.0001)
    loss = nn.CrossEntropyLoss()

    # train the model
    for epoch in range(NUM_EPOCHS):
        print("training epoch {}...".format(epoch + 1), end=" ")

        predictions = model(input_seq, target_seq)
        optimizer.zero_grad()
        # flatten predictions to (batch-size * seq-len x vocab-size), and targets to (batch-size * seq-len)
        current_loss = loss(
            predictions.view(predictions.size(0) * predictions.size(1), predictions.size(2)),
            target_seq.view(-1)
        )
        current_loss.backward()
        optimizer.step()

        print("OK (loss: {:.6f})".format(current_loss.item()))

    # put model in evaluation mode
    model.eval()

    print()
    print("Final Probabilities of Translations:")
    print("------------------------------------")
    eval_model(model, input_seq, target_seq)

    # randomly sample outputs from the input sequences based on the probabilities computed by the trained model
    sampled_output = transformer.sample_output(model, input_seq, EOS.index, PAD.index, target_seq.size(1))

    print()
    print("Sampled Outputs:")
    print("----------------")
    for sample_idx in range(input_seq.size(0)):
        for token in input_seq[sample_idx].tolist():
            print(idx_to_word[token], end=" ")
        print(" => ", end=" ")
        for token in sampled_output[sample_idx].tolist():
            print(idx_to_word[token], end=" ")
        print()


if __name__ == "__main__":
    main()
Redistributions of source code must retain the above copyright notice, this\n" 31 | " list of conditions and the following disclaimer.\n" 32 | "2. Redistributions in binary form must reproduce the above copyright notice,\n" 33 | " this list of conditions and the following disclaimer in the documentation\n" 34 | " and/or other materials provided with the distribution.\n" 35 | "\n" 36 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 37 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 38 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 39 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 40 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 41 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 42 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 43 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 44 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 45 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 46 | ) 47 | __license__ = "BSD-2-Clause" 48 | __version__ = "2018.1" 49 | __date__ = "Oct 30, 2018" 50 | __maintainer__ = "Patrick Hohenecker" 51 | __email__ = "mail@paho.at" 52 | __status__ = "Development" 53 | 54 | 55 | BATCH_SIZE = 128 56 | """int: The size of the generated batch of input sequences.""" 57 | 58 | EMBEDDING_SIZE = 100 59 | """int: The size of each token of the input sequences.""" 60 | 61 | GPU = False # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< SET THIS TO True, IF YOU ARE USING A MACHINE WITH A GPU! 
62 | """bool: Indicates whether to make use of a GPU.""" 63 | 64 | HIDDEN_SIZE = 300 65 | """int: The size of any hidden layers.""" 66 | 67 | INPUT_LEN = 100 68 | """int: The length of the (randomly) generated input sequence.""" 69 | 70 | NUM_RUNS = 5 71 | """int: The total number of times that each model is ran and the according execution time tracked.""" 72 | 73 | VOCAB_SIZE = 10000 74 | """int: The size of the used vocabulary.""" 75 | 76 | 77 | # ==================================================================================================================== # 78 | # R E C U R R E N T M O D E L # 79 | # ==================================================================================================================== # 80 | 81 | 82 | class EncDecWithAttn(nn.Module): 83 | """A very simple attention-based encoder-decoder model. 84 | 85 | Here is a schematic of the implemented model: 86 | 87 | word-1 \ 88 | word-2 | +-----------------+ +--------------------+ +----------------+ 89 | word-3 > -> | Encoder (BiGRU) | -> | Decoder (GRU+attn) | -> | linear+softmax | -> distribution over vocabulary 90 | word-4 | +-----------------+ +--------------------+ +----------------+ 91 | ... 
/ 92 | """ 93 | 94 | def __init__(self, emb: nn.Embedding, hidden_size: int): 95 | super().__init__() 96 | self._emb = emb 97 | 98 | # create the decoder 99 | self._encoder = nn.GRU( 100 | emb.embedding_dim, # input_size 101 | hidden_size, # hidden_size, 102 | bidirectional=True 103 | ) 104 | self._encoder_init = nn.Parameter(torch.FloatTensor(2, hidden_size)) 105 | 106 | # create the decoder 107 | self._decoder = nn.GRU( 108 | 2 * emb.embedding_dim + 2 * hidden_size, # input_size 109 | hidden_size # hidden_size 110 | ) 111 | self._decoder_init_hidden = nn.Parameter(torch.FloatTensor(1, hidden_size)) 112 | self._decoder_init_output = nn.Parameter(torch.FloatTensor(1, emb.embedding_dim)) 113 | 114 | # add an additional feed-forward layer to be used on top of the decoder 115 | self._output_proj = nn.Sequential( 116 | nn.Linear(hidden_size, emb.num_embeddings), 117 | nn.Softmax(dim=1) 118 | ) 119 | 120 | # create module for computing the attention scores 121 | self._attn = nn.Sequential( 122 | nn.Linear(3 * hidden_size, hidden_size, bias=False), 123 | nn.Tanh(), 124 | nn.Linear(hidden_size, 1, bias=False), 125 | nn.Softmax(dim=1) 126 | ) 127 | 128 | self.reset_parameters() 129 | 130 | def forward(self, input_seq: torch.LongTensor, target_seq) -> torch.FloatTensor: 131 | # embed + encode input sequence 132 | input_seq = self._emb(input_seq) 133 | enc_seq, _ = self._encoder( 134 | input_seq, 135 | self._encoder_init.expand(input_seq.size(1), *self._encoder_init.size()).transpose(0, 1).contiguous() 136 | ) 137 | 138 | all_outputs = [] # used the store the outputs for all time steps 139 | 140 | # these are used to store the decoder's last state as well as last output produced 141 | last_hidden = self._decoder_init_hidden \ 142 | .expand(input_seq.size(1), *self._decoder_init_hidden.size()) \ 143 | .transpose(0, 1) 144 | last_hidden = last_hidden.contiguous() 145 | last_output = self._decoder_init_output.expand(input_seq.size(1), self._decoder_init_output.size(1)) 146 | 
last_output = last_output.contiguous() 147 | 148 | # iterate over the input sequence token-by-token, and run the decoder 149 | for idx, token in enumerate(input_seq): 150 | 151 | # run attention to compute a glimpse of the encoded input sequence 152 | attn_scores = torch.cat( 153 | [enc_seq, last_hidden.expand(enc_seq.size(0), last_hidden.size(1), last_hidden.size(2))], 154 | dim=2 155 | ) 156 | attn_scores = self._attn(attn_scores) 157 | glimpse = (enc_seq * attn_scores).sum(dim=0) 158 | 159 | # add a 0-th time-dimension to all inputs of the decoder 160 | token = token.unsqueeze(0) 161 | glimpse = glimpse.unsqueeze(0) 162 | last_output = last_output.unsqueeze(0) 163 | 164 | # run the decoder + softmax on top 165 | _, last_hidden = self._decoder(torch.cat([token, glimpse, last_output], dim=2), last_hidden) 166 | last_output = self._output_proj(last_hidden.squeeze(0)) 167 | all_outputs.append(last_output) 168 | 169 | # fill in target output 170 | last_output = target_seq[idx] 171 | last_output = self._emb(last_output) 172 | 173 | return torch.stack(all_outputs) 174 | 175 | def reset_parameters(self): 176 | """Resets all tunable parameters of the module.""" 177 | self._encoder.reset_parameters() 178 | self._decoder.reset_parameters() 179 | self._attn[0].reset_parameters() 180 | self._attn[2].reset_parameters() 181 | self._output_proj[0].reset_parameters() 182 | nn.init.normal_(self._encoder_init, std=0.1) 183 | nn.init.normal_(self._decoder_init_hidden, std=0.1) 184 | nn.init.normal_(self._decoder_init_output, std=0.1) 185 | 186 | 187 | # ==================================================================================================================== # 188 | # M A I N # 189 | # ==================================================================================================================== # 190 | 191 | 192 | def main(): 193 | # create an embedding matrix + randomly sample input as well as target sequence 194 | emb = nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE) 195 
| input_seq = torch.from_numpy(np.random.randint(1, VOCAB_SIZE - 1, (INPUT_LEN, BATCH_SIZE))) 196 | target_seq = torch.from_numpy(np.random.randint(1, VOCAB_SIZE - 1, (INPUT_LEN, BATCH_SIZE))) 197 | # -> we assume that index 0 is the token 198 | 199 | # create the models being compared 200 | recurrent_model = EncDecWithAttn(emb, HIDDEN_SIZE) 201 | transformer_model = transformer.Transformer( 202 | emb, # text_emb 203 | 0, # pad_index 204 | emb.num_embeddings, # output_size 205 | max_seq_len=INPUT_LEN, 206 | dim_model=HIDDEN_SIZE, 207 | num_layers=1 208 | ) 209 | 210 | # move everything to the GPU, if possible 211 | if GPU: 212 | input_seq = input_seq.cuda() 213 | target_seq = target_seq.cuda() 214 | emb.cuda() 215 | recurrent_model.cuda() 216 | transformer_model.cuda() 217 | 218 | # measure how long it takes the recurrent model to process the data 219 | print("Testing the recurrent attention-based encoder-decoder model...") 220 | times = [] 221 | for idx in range(NUM_RUNS): 222 | start = time.time() 223 | recurrent_model(input_seq, target_seq) 224 | times.append(time.time() - start) 225 | print("Run {} finished in {:.3f}s".format(idx + 1, times[-1])) 226 | print("Avg. duration: {:.3f}s\n".format(np.mean(times))) 227 | 228 | # flip the first two dimensions of the data, as the transformer expects the first dimension to be the batch 229 | input_seq = input_seq.transpose(0, 1) 230 | target_seq = target_seq.transpose(0, 1) 231 | 232 | # measure how long it takes the transformer model to process the data 233 | print("Testing the transformer model...") 234 | times = [] 235 | for idx in range(NUM_RUNS): 236 | start = time.time() 237 | transformer_model(input_seq, target_seq) 238 | times.append(time.time() - start) 239 | print("Run {} finished in {:.3f}s".format(idx + 1, times[-1])) 240 | print("Avg. 
duration: {:.3f}s".format(np.mean(times))) 241 | 242 | 243 | if __name__ == "__main__": 244 | main() 245 | -------------------------------------------------------------------------------- /run-tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 2-Clause BSD License 4 | # 5 | # Copyright (c) 2018, Patrick Hohenecker 6 | # All rights reserved. 7 | # 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, are permitted provided that the following conditions are met: 10 | # 11 | # 1. Redistributions of source code must retain the above copyright notice, this 12 | # list of conditions and the following disclaimer. 13 | # 2. Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 21 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Installation script of the ``transformer`` package."""


# NOTE: setuptools is used rather than distutils, as distutils neither understands the options ``install_requires``
#       and ``python_requires`` that are used below, nor is it part of the standard library as of Python 3.12
from setuptools import setup


__author__ = "Patrick Hohenecker"
__copyright__ = (
        "Copyright (c) 2018, Patrick Hohenecker\n"
        "All rights reserved.\n"
        "\n"
        "Redistribution and use in source and binary forms, with or without\n"
        "modification, are permitted provided that the following conditions are met:\n"
        "\n"
        "1. Redistributions of source code must retain the above copyright notice, this\n"
        "   list of conditions and the following disclaimer.\n"
        "2. Redistributions in binary form must reproduce the above copyright notice,\n"
        "   this list of conditions and the following disclaimer in the documentation\n"
        "   and/or other materials provided with the distribution.\n"
        "\n"
        "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n"
        "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n"
        "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n"
        "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n"
        "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n"
        "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n"
        "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n"
        "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n"
        "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n"
        "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
)
__license__ = "BSD-2-Clause"
__version__ = "2018.1"
__date__ = "Aug 21, 2018"
__maintainer__ = "Patrick Hohenecker"
__email__ = "mail@paho.at"
__status__ = "Development"


# read the long description from the read me file
# -> the encoding is specified explicitly, such that the result does not depend on the locale of the machine
with open("README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()


# NOTE: the copyright notice is not passed to setup(), as ``copyright`` is not among the supported keywords
setup(
        author="Patrick Hohenecker",
        author_email="mail@paho.at",
        classifiers=[
                "Programming Language :: Python :: 3.6",
                "Programming Language :: Python :: 3.7"
        ],
        data_files=[
                (".", ["LICENSE", "README.md"])
        ],
        description="A PyTorch implementation of the Transformer model from \"Attention Is All You Need\".",
        install_requires=[
                "insanity>=2017.1",
                "numpy>=1.15.0",
                "torch>=0.4.1"
        ],
        license="BSD-2-Clause",
        long_description=long_description,
        long_description_content_type="text/markdown",
        name="transformer",
        package_dir={"": "src/main/python"},
        packages=[
                "transformer",
                "transformer.bert"
        ],
        python_requires=">=3.6",
        url="https://github.com/phohenecker/pytorch-transformer",
        version="2018.1"
)
PyTorch implementation of the Transformer model from "Attention Is All You Need".""" 4 | 5 | 6 | import transformer.bert as bert 7 | 8 | from transformer.encoder import Encoder 9 | from transformer.transformer import Transformer 10 | from transformer.transformer_tools import * 11 | 12 | 13 | __author__ = "Patrick Hohenecker" 14 | __copyright__ = ( 15 | "Copyright (c) 2018, Patrick Hohenecker\n" 16 | "All rights reserved.\n" 17 | "\n" 18 | "Redistribution and use in source and binary forms, with or without\n" 19 | "modification, are permitted provided that the following conditions are met:\n" 20 | "\n" 21 | "1. Redistributions of source code must retain the above copyright notice, this\n" 22 | " list of conditions and the following disclaimer.\n" 23 | "2. Redistributions in binary form must reproduce the above copyright notice,\n" 24 | " this list of conditions and the following disclaimer in the documentation\n" 25 | " and/or other materials provided with the distribution.\n" 26 | "\n" 27 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 28 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 29 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 30 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 31 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 32 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 33 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 34 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 35 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 36 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 
37 | ) 38 | __license__ = "BSD-2-Clause" 39 | __version__ = "2018.1" 40 | __date__ = "Aug 21, 2018" 41 | __maintainer__ = "Patrick Hohenecker" 42 | __email__ = "mail@paho.at" 43 | __status__ = "Development" 44 | -------------------------------------------------------------------------------- /src/main/python/transformer/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """This package implements modules for pretraining the encoder part of the transformer via BERT.""" 4 | 5 | 6 | from transformer.bert.mlm_loss import MLMLoss 7 | 8 | 9 | __author__ = "Patrick Hohenecker" 10 | __copyright__ = ( 11 | "Copyright (c) 2019, Patrick Hohenecker\n" 12 | "All rights reserved.\n" 13 | "\n" 14 | "Redistribution and use in source and binary forms, with or without\n" 15 | "modification, are permitted provided that the following conditions are met:\n" 16 | "\n" 17 | "1. Redistributions of source code must retain the above copyright notice, this\n" 18 | " list of conditions and the following disclaimer.\n" 19 | "2. Redistributions in binary form must reproduce the above copyright notice,\n" 20 | " this list of conditions and the following disclaimer in the documentation\n" 21 | " and/or other materials provided with the distribution.\n" 22 | "\n" 23 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 24 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 25 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 26 | "DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 27 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 28 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 29 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 30 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 31 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 32 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 33 | ) 34 | __license__ = "BSD-2-Clause" 35 | __version__ = "2019.1" 36 | __date__ = "23 Apr 2019" 37 | __maintainer__ = "Patrick Hohenecker" 38 | __email__ = "mail@paho.at" 39 | __status__ = "Development" 40 | -------------------------------------------------------------------------------- /src/main/python/transformer/bert/mlm_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import math 5 | import numbers 6 | import random 7 | 8 | import insanity 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | import transformer.encoder as encoder 14 | import transformer.util as util 15 | 16 | 17 | __author__ = "Patrick Hohenecker" 18 | __copyright__ = ( 19 | "Copyright (c) 2019, Patrick Hohenecker\n" 20 | "All rights reserved.\n" 21 | "\n" 22 | "Redistribution and use in source and binary forms, with or without\n" 23 | "modification, are permitted provided that the following conditions are met:\n" 24 | "\n" 25 | "1. Redistributions of source code must retain the above copyright notice, this\n" 26 | " list of conditions and the following disclaimer.\n" 27 | "2. 
Redistributions in binary form must reproduce the above copyright notice,\n" 28 | " this list of conditions and the following disclaimer in the documentation\n" 29 | " and/or other materials provided with the distribution.\n" 30 | "\n" 31 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 32 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 33 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 34 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 35 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 36 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 37 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 38 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 39 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 40 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 41 | ) 42 | __license__ = "BSD-2-Clause" 43 | __version__ = "2019.1" 44 | __date__ = "23 Apr 2019" 45 | __maintainer__ = "Patrick Hohenecker" 46 | __email__ = "mail@paho.at" 47 | __status__ = "Development" 48 | 49 | 50 | class MLMLoss(nn.Module): 51 | """The masked language-model (MLM) loss function for pretraining a transformer encoder. 52 | 53 | Unlike other loss functions, an ``MLMLoss`` has trainable parameters, which are part of a linear layer with a 54 | softmax on top (cf. :attr:`output_layer`) that is used for predicting masked/obliterated tokens. These have to be 55 | optimized together with the parameters of the pretrained encoder. 
56 | """ 57 | 58 | def __init__( 59 | self, 60 | model: encoder.Encoder, 61 | word_emb: nn.Embedding, 62 | pos_emb: nn.Embedding, 63 | mask_index: int, 64 | prediction_rate: numbers.Real = 0.15, 65 | mask_rate: numbers.Real = 0.8, 66 | random_rate: numbers.Real = 0.1 67 | ): 68 | """Creates a new instance of ``BERTLoss`. 69 | 70 | Args: 71 | model (encoder.Encoder): The encoder model being pretrained. 72 | word_emb (nn.Embedding): The used word embeddings. 73 | pos_emb (nn.Embedding): The used positional embeddings. 74 | mask_index (int): The index of the mask token. 75 | prediction_rate (numbers.Real, optional): The percentage of tokens in each training sequence that 76 | predictions are computed for, which is set to ``0.8``, by default. 77 | mask_rate (numbers.Real, optional): Among all tokens that predictions are computed for, the percentage of 78 | tokens that are replaced with the mask token, as specified by ``mask_index``. This is set to ``0.8``, by 79 | default. 80 | random_rate (numbers.Real, optional): Among all tokens that predictions are computed for, the percentage of 81 | tokens that are randomly replaced with other tokens. This is set to ``0.1``, by default. 
82 | """ 83 | super().__init__() 84 | 85 | # sanitize args 86 | insanity.sanitize_type("model", model, encoder.Encoder) 87 | insanity.sanitize_type("word_emb", word_emb, nn.Embedding) 88 | insanity.sanitize_type("pos_emb", word_emb, nn.Embedding) 89 | if pos_emb.embedding_dim != word_emb.embedding_dim: 90 | raise ValueError(" is not compatible with !") 91 | insanity.sanitize_type("mask_index", mask_index, int) 92 | if mask_index < 0 or mask_index >= word_emb.num_embeddings: 93 | raise ValueError("The does not exist in !") 94 | insanity.sanitize_type("prediction_rate", prediction_rate, numbers.Real) 95 | prediction_rate = float(prediction_rate) 96 | insanity.sanitize_range("prediction_rate", prediction_rate, minimum=0, maximum=1) 97 | insanity.sanitize_type("mask_rate", mask_rate, numbers.Real) 98 | mask_rate = float(mask_rate) 99 | insanity.sanitize_range("mask_rate", mask_rate, minimum=0, maximum=1) 100 | insanity.sanitize_type("random_rate", random_rate, numbers.Real) 101 | random_rate = float(random_rate) 102 | insanity.sanitize_range("random_rate", random_rate, minimum=0, maximum=1) 103 | if mask_rate + random_rate > 1: 104 | raise ValueError(" + has to be at most 1!") 105 | 106 | # store args 107 | self._mask_index = mask_index 108 | self._mask_rate = mask_rate 109 | self._model = model 110 | self._pad_index = model.pad_index 111 | self._pos_emb = pos_emb 112 | self._prediction_rate = prediction_rate 113 | self._random_rate = random_rate 114 | self._word_emb = word_emb 115 | 116 | # create an output layer, which is trained together with the model, for predicting masked tokens 117 | self._output_layer = nn.Sequential( 118 | nn.Linear(self._word_emb.embedding_dim, self._word_emb.num_embeddings), 119 | nn.Softmax(dim=1) 120 | ) 121 | 122 | # create the used loss function 123 | self._loss = nn.CrossEntropyLoss() 124 | 125 | # PROPERTIES ##################################################################################################### 126 | 127 | @property 128 | 
def forward(self, batch: torch.LongTensor) -> torch.FloatTensor:
    """Computes the masked-language-model loss for a batch of token sequences.

    For every sample, a share of ``prediction_rate`` of its non-padding tokens is selected for prediction. Of the
    selected tokens, a share of ``mask_rate`` is replaced with the mask token, a share of ``random_rate`` is replaced
    with random tokens, and the remaining ones are left untouched. The corrupted batch is embedded, encoded by the
    model, and the original tokens at the selected positions are predicted from the resulting encodings.

    Notice:
        The loss is computed from the raw scores of the linear output layer rather than from the softmax output of
        ``output_layer``, as ``nn.CrossEntropyLoss`` applies a log-softmax itself and expects unnormalized scores.

    Args:
        batch (torch.LongTensor): A batch of training data, as (batch-size x max-seq-len)-tensor.

    Returns:
        torch.FloatTensor: The computed loss.

    Raises:
        TypeError: If ``batch`` is not a tensor of dtype ``torch.int64``.
        ValueError: If ``batch`` is not a 2d tensor.
    """
    # sanitize args
    insanity.sanitize_type("batch", batch, torch.Tensor)
    if batch.dtype != torch.int64:
        raise TypeError("<batch> has to be a LongTensor!")
    if batch.dim() != 2:
        raise ValueError("<batch> has to be a 2d tensor!")

    # create the padding mask to use
    padding_mask = util.create_padding_mask(batch, self._pad_index)

    # create a tensor of position indices, which is used to retrieve the according positional embeddings below
    index_seq = torch.arange(batch.size(1), dtype=batch.dtype, device=batch.device)
    index_seq = index_seq.unsqueeze(0).expand(batch.size(0), -1)

    # compute the number of non-padding tokens for all samples in the batch
    seq_len = (batch != self._pad_index).sum(dim=1).tolist()

    # randomly choose the tokens to compute predictions for
    pred_mask = torch.zeros_like(batch)    # all tokens being predicted
    mask_mask = torch.zeros_like(batch)    # tokens replaced with the mask token
    random_mask = torch.zeros_like(batch)  # tokens replaced with random tokens
    for sample_idx, sample_len in enumerate(seq_len):  # iterate over all samples in the batch

        # determine how many tokens to compute predictions for
        num_pred = int(math.ceil(sample_len * self._prediction_rate))  # num of tokens predictions are computed for
        num_mask = int(math.floor(num_pred * self._mask_rate))         # num of tokens replaced with the mask token
        num_random = int(math.ceil(num_pred * self._random_rate))      # num of tokens randomly replaced

        # randomly select positions to compute predictions for
        # NOTE(review): this treats the first ``sample_len`` positions as the real tokens, i.e., it presumably
        #               expects padding to be appended at the end of each sequence -> confirm against callers
        pred_indices = list(range(sample_len))
        random.shuffle(pred_indices)
        pred_indices = pred_indices[:num_pred]

        # every selected token is predicted, no matter whether/how it is replaced
        for token_idx in pred_indices:
            pred_mask[sample_idx, token_idx] = 1

        # the first share of the selected tokens is replaced with the mask token
        for token_idx in pred_indices[:num_mask]:
            mask_mask[sample_idx, token_idx] = 1

        # the next share is replaced with random tokens (all remaining selected tokens are left untouched)
        for token_idx in pred_indices[num_mask:(num_mask + num_random)]:
            random_mask[sample_idx, token_idx] = 1

    # replace predicted tokens in the batch appropriately
    random_tokens = torch.randint(
            0,
            self._word_emb.num_embeddings,
            batch.size(),
            dtype=batch.dtype,
            device=batch.device
    )
    masked_batch = (
            batch * (1 - mask_mask) * (1 - random_mask) +
            mask_mask * self._mask_index +
            random_mask * random_tokens
    )

    # embed the batch
    masked_batch = self._word_emb(masked_batch) + self._pos_emb(index_seq)

    # encode the sequences in the batch
    enc = self._model(masked_batch, padding_mask)

    # turn the encodings, the target token indices (that we seek to predict), and the prediction mask into
    # matrices/vectors, such that each row corresponds with one token
    enc = enc.view(enc.size(0) * enc.size(1), enc.size(2))
    target = batch.view(-1)

    # turn the prediction mask into a tensor of row indices (to select below)
    pred_indices = torch.nonzero(pred_mask.view(-1)).squeeze(1)

    # fetch encodings and target values of those tokens that are being predicted
    enc = enc.index_select(0, pred_indices)
    target = target.index_select(0, pred_indices)

    # compute unnormalized scores with the linear part of the output layer only (element 0 of the sequential), since
    # the cross-entropy loss normalizes them itself -> feeding it softmax probabilities would squash the loss
    logits = self._output_layer[0](enc)

    return self._loss(logits, target)

def reset_parameters(self) -> None:
    """Resets the loss' tunable parameters that are being trained to predict masked/obliterated tokens.

    Only the linear layer of the output layer is reset. Notice that this function does **not** reset the used
    embeddings.
    """
    self._output_layer[0].reset_parameters()
Redistributions in binary form must reproduce the above copyright notice,\n" 26 | " this list of conditions and the following disclaimer in the documentation\n" 27 | " and/or other materials provided with the distribution.\n" 28 | "\n" 29 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 30 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 31 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 32 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 33 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 34 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 35 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 36 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 37 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 38 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 
class Decoder(nn.Module, enc_dec_base.EncDecBase):
    """The decoder of the Transformer model, which is realized as a stack of decoder layers."""

    def __init__(self, *args, **kwargs):
        """Creates a new instance of ``Decoder``.

        All args are handed on, as they are, to the constructor of :class:`enc_dec_base.EncDecBase`.
        """
        nn.Module.__init__(self)
        enc_dec_base.EncDecBase.__init__(self, *args, **kwargs)

        # one independent layer per level of the stack
        stack = [_DecoderLayer(self) for _ in range(self._num_layers)]
        self._layers = nn.ModuleList(stack)

    #  METHODS  ########################################################################################################

    def forward(
            self,
            in_sequence: torch.FloatTensor,
            out_sequence: torch.FloatTensor,
            padding_mask: torch.ByteTensor = None
    ) -> torch.FloatTensor:
        """Runs the decoder.

        Args:
            in_sequence (torch.FloatTensor): The input sequence as (batch-size x in-seq-len x dim_model)-tensor.
            out_sequence (torch.FloatTensor): The output sequence as (batch-size x out-seq-len x dim_model)-tensor.
            padding_mask (torch.ByteTensor, optional): Optionally, a padding mask as
                (batch-size x in-seq-len x in-seq-len)-tensor. To that end, ``1``s indicate those positions that are
                part of the according sequence, and ``0``s mark padding tokens.

        Returns:
            FloatTensor: The computed output as (batch_size x out-seq-len x dim_model)-tensor.
        """
        # shape checks for the two sequences
        assert in_sequence.dim() == 3
        assert in_sequence.size(2) == self._dim_model
        assert out_sequence.dim() == 3
        assert out_sequence.size(0) == in_sequence.size(0)
        assert out_sequence.size(2) == self._dim_model

        # shape check for the (optional) padding mask
        if padding_mask is not None:
            assert tuple(padding_mask.size()) == (in_sequence.size(0), in_sequence.size(1), in_sequence.size(1))

        # the mask has to be created from the original output sequence, i.e., before it is shifted
        target_mask = util.create_shifted_output_mask(out_sequence)

        # shift provided target output to the right
        decoded = util.shift_output_sequence(out_sequence)

        # run the sequence through the stack, one layer after the other
        for layer in self._layers:
            decoded = layer(in_sequence, decoded, padding_mask, target_mask)

        return decoded

    def reset_parameters(self) -> None:
        """Resets the trainable parameters of all layers of the decoder."""
        for layer in self._layers:
            layer.reset_parameters()


# ==================================================================================================================== #
#  CLASS  _ D E C O D E R  L A Y E R                                                                                   #
# ==================================================================================================================== #


class _DecoderLayer(nn.Module):
    """One layer of the decoder.

    The same normalization and the same residual-dropout module are applied after each of the three sub-layers.

    Attributes:
        attn_1 (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read from the output sequence.
        attn_2 (:class:`mha.MultiHeadAttention`): The encoder-decoder attention mechanism.
        feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanisms.
    """

    def __init__(self, parent: Decoder):
        """Creates a new instance of ``_DecoderLayer``.

        Args:
            parent (Decoder): The decoder that the layer is created for, which provides all of the configuration.
        """
        super().__init__()

        # both attention mechanisms are configured in exactly the same way
        attention_config = (
                parent.num_heads,
                parent.dim_model,
                parent.dim_keys,
                parent.dim_values,
                parent.attention_dropout
        )

        self.attn_1 = mha.MultiHeadAttention(*attention_config)
        self.attn_2 = mha.MultiHeadAttention(*attention_config)
        self.feed_forward = ffl.FeedForwardLayer(parent.dim_model)
        self.norm = normalization.Normalization()
        self.dropout = nn.Dropout(parent.residual_dropout)

    #  METHODS  ########################################################################################################

    def forward(
            self,
            in_sequence: torch.FloatTensor,
            out_sequence: torch.FloatTensor,
            padding_mask: torch.ByteTensor,
            shifted_output_mask: torch.ByteTensor
    ) -> torch.FloatTensor:
        """Runs the layer.

        Args:
            in_sequence (torch.FloatTensor): The input sequence as (batch-size x in-seq-len x dim-model)-tensor.
            out_sequence (torch.FloatTensor): The output sequence as (batch-size x out-seq-len x dim-model)-tensor.
            padding_mask (torch.ByteTensor): A padding mask as (batch-size x in-seq-len x in-seq-len)-tensor or
                ``None`` if no mask is used.
            shifted_output_mask (torch.ByteTensor): The mask that is used for the attention over the output sequence.

        Returns:
            FloatTensor: The computed outputs as (batch-size x out-seq-len x dim-model)-tensor.
        """
        in_len = in_sequence.size(1)
        out_len = out_sequence.size(1)

        # the enc-dec attention needs one mask row per position of the output sequence
        if padding_mask is not None and in_len != out_len:
            if in_len < out_len:
                # too few rows -> replicate the first one for every output position
                padding_mask = padding_mask[:, :1, :].repeat(1, out_len, 1)
            else:
                # too many rows -> keep the leading ones only
                padding_mask = padding_mask[:, :out_len, :]

        # sub-layer 1: attention over the output sequence itself
        self_attended = self.attn_1(out_sequence, out_sequence, out_sequence, mask=shifted_output_mask)
        out_sequence = self.norm(self.dropout(self_attended) + out_sequence)

        # sub-layer 2: attention over the input sequence
        cross_attended = self.attn_2(out_sequence, in_sequence, in_sequence, mask=padding_mask)
        out_sequence = self.norm(self.dropout(cross_attended) + out_sequence)

        # sub-layer 3: position-wise feed-forward network
        transformed = self.feed_forward(out_sequence)
        return self.norm(self.dropout(transformed) + out_sequence)

    def reset_parameters(self) -> None:
        """Resets the trainable parameters of both attention mechanisms and of the feed-forward layer."""
        for module in (self.attn_1, self.attn_2, self.feed_forward):
            module.reset_parameters()
class EncDecBase(object):
    """A base class that implements common functionality of the encoder and decoder parts of the Transformer model.

    The class stores the configuration that is shared by both parts, and validates every value whenever it is set,
    either via the constructor or via one of the properties.
    """

    def __init__(
            self,
            num_layers: int,
            num_heads: int,
            dim_model: int,
            dim_keys: int,
            dim_values: int,
            residual_dropout: numbers.Real,
            attention_dropout: numbers.Real,
            pad_index: int
    ):
        """Creates a new instance of ``EncDecBase``.

        Args:
            num_layers (int): The number of layers to use.
            num_heads (int): The number of attention heads to use.
            dim_model (int): The dimension to use for all layers. This is called d_model, in the paper.
            dim_keys (int): The size of the keys provided to the attention mechanism. This is called d_k, in the paper.
            dim_values (int): The size of the values provided to the attention mechanism. This is called d_v, in the
                paper.
            residual_dropout (numbers.Real): The dropout probability for residual connections (before they are added to
                the sublayer output).
            attention_dropout (numbers.Real): The dropout probability for values provided by the attention mechanism.
            pad_index (int): The index that indicates a padding token in the input sequence.

        Raises:
            TypeError: If any of the args is of an unexpected type.
            ValueError: If any of the args has an illegal value.
        """
        super().__init__()

        # define attributes
        self._attention_dropout = None
        self._dim_keys = None
        self._dim_model = None
        self._dim_values = None
        self._num_heads = None
        self._num_layers = None
        self._pad_index = None
        self._residual_dropout = None

        # specify properties (the setters take care of sanitizing the provided values)
        self.attention_dropout = attention_dropout
        self.dim_keys = dim_keys
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pad_index = pad_index
        self.residual_dropout = residual_dropout

    #  PROPERTIES  #####################################################################################################

    @property
    def attention_dropout(self) -> float:
        """float: The dropout probability for values provided by the attention mechanism."""
        return self._attention_dropout

    @attention_dropout.setter
    def attention_dropout(self, attention_dropout: numbers.Real) -> None:
        self._sanitize_probability("attention_dropout", attention_dropout)
        self._attention_dropout = float(attention_dropout)

    @property
    def dim_keys(self) -> int:
        """int: The size of the keys provided to the attention mechanism.

        This value is called d_k, in "Attention Is All You Need".
        """
        return self._dim_keys

    @dim_keys.setter
    def dim_keys(self, dim_keys: int) -> None:
        self._sanitize_pos_int("dim_keys", dim_keys)
        self._dim_keys = dim_keys

    @property
    def dim_model(self) -> int:
        """int: The dimension to use for all layers.

        This value is called d_model, in "Attention Is All You Need".
        """
        return self._dim_model

    @dim_model.setter
    def dim_model(self, dim_model: int) -> None:
        self._sanitize_pos_int("dim_model", dim_model)
        self._dim_model = dim_model

    @property
    def dim_values(self) -> int:
        """int: The size of the values provided to the attention mechanism.

        This value is called d_v, in "Attention Is All You Need".
        """
        return self._dim_values

    @dim_values.setter
    def dim_values(self, dim_values: int) -> None:
        self._sanitize_pos_int("dim_values", dim_values)
        self._dim_values = dim_values

    @property
    def num_heads(self) -> int:
        """int: The number of attention heads used by the implemented module."""
        return self._num_heads

    @num_heads.setter
    def num_heads(self, num_heads: int) -> None:
        self._sanitize_pos_int("num_heads", num_heads)
        self._num_heads = num_heads

    @property
    def num_layers(self) -> int:
        """int: The number of layers used by the implemented module."""
        return self._num_layers

    @num_layers.setter
    def num_layers(self, num_layers: int) -> None:
        self._sanitize_pos_int("num_layers", num_layers)
        self._num_layers = num_layers

    @property
    def pad_index(self) -> int:
        """int: The index that indicates a padding token in the input sequence."""
        return self._pad_index

    @pad_index.setter
    def pad_index(self, pad_index: int) -> None:
        # unlike the other integer args, the padding index may be 0
        if not isinstance(pad_index, int):
            raise TypeError("<pad_index> has to be an integer!")
        if pad_index < 0:
            raise ValueError("<pad_index> has to be non-negative!")
        self._pad_index = pad_index

    @property
    def residual_dropout(self) -> float:
        """float: The dropout probability for residual connections (before they are added to the sublayer output)."""
        return self._residual_dropout

    @residual_dropout.setter
    def residual_dropout(self, residual_dropout: numbers.Real) -> None:
        self._sanitize_probability("residual_dropout", residual_dropout)
        self._residual_dropout = float(residual_dropout)

    #  METHODS  ########################################################################################################

    @staticmethod
    def _sanitize_pos_int(arg_name: str, arg_value) -> None:
        """Ensures that the provided arg is a positive integer.

        Args:
            arg_name (str): The name of the arg being sanitized.
            arg_value: The value being sanitized.

        Raises:
            TypeError: If ``arg_value`` is not an ``int``.
            ValueError: If ``arg_value`` is not a positive number.
        """
        if not isinstance(arg_value, int):
            raise TypeError("<{}> has to be an integer!".format(arg_name))
        if arg_value < 1:
            raise ValueError("<{}> has to be > 0!".format(arg_name))

    @staticmethod
    def _sanitize_probability(arg_name: str, arg_value) -> None:
        """Ensures that the provided arg is a probability.

        Args:
            arg_name (str): The name of the arg being sanitized.
            arg_value: The value being sanitized.

        Raises:
            TypeError: If ``arg_value`` is not a ``numbers.Real``.
            ValueError: If ``arg_value`` is not in [0, 1], which includes NaN.
        """
        if not isinstance(arg_value, numbers.Real):
            raise TypeError("<{}> has to be a real number!".format(arg_name))

        # the upper bound is checked as a negated <= (rather than as >), since every comparison with NaN is false, and
        # NaN would thus slip through otherwise
        if arg_value < 0 or not float(arg_value) <= 1:
            raise ValueError("<{}> has to be in [0, 1]!".format(arg_name))
class Encoder(nn.Module, enc_dec_base.EncDecBase):
    """The encoder of the Transformer model, which is realized as a stack of encoder layers."""

    def __init__(self, *args, **kwargs):
        """Creates a new instance of ``Encoder``.

        All args are handed on, as they are, to the constructor of :class:`enc_dec_base.EncDecBase`.
        """
        nn.Module.__init__(self)
        enc_dec_base.EncDecBase.__init__(self, *args, **kwargs)

        # one independent layer per level of the stack
        stack = [_EncoderLayer(self) for _ in range(self._num_layers)]
        self._layers = nn.ModuleList(stack)

    #  METHODS  ########################################################################################################

    def forward(self, sequence: torch.FloatTensor, padding_mask: torch.ByteTensor = None) -> torch.FloatTensor:
        """Runs the encoder.

        Args:
            sequence (torch.FloatTensor): The input sequence as (batch-size x seq-len x dim-model)-tensor.
            padding_mask (torch.ByteTensor, optional): Optionally, a padding mask as
                (batch-size x in-seq-len x in-seq-len)-tensor. To that end, ``1``s indicate those positions that are
                part of the according sequence, and ``0``s mark padding tokens.

        Returns:
            FloatTensor: The encoded sequence as (batch_size x seq_len x dim_model)-tensor.
        """
        assert sequence.dim() == 3
        assert sequence.size(2) == self._dim_model

        # run the sequence through the stack, one layer after the other
        encoded = sequence
        for layer in self._layers:
            encoded = layer(encoded, padding_mask)

        return encoded

    def reset_parameters(self) -> None:
        """Resets the trainable parameters of all layers of the encoder."""
        for layer in self._layers:
            layer.reset_parameters()


# ==================================================================================================================== #
#  CLASS  _ E N C O D E R  L A Y E R                                                                                   #
# ==================================================================================================================== #


class _EncoderLayer(nn.Module):
    """One layer of the encoder.

    The same normalization and the same residual-dropout module are applied after both sub-layers.

    Attributes:
        attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence.
        feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism.
    """

    def __init__(self, parent: Encoder):
        """Creates a new instance of ``_EncoderLayer``.

        Args:
            parent (Encoder): The encoder that the layer is created for, which provides all of the configuration.
        """
        super().__init__()

        attention_config = (
                parent.num_heads,
                parent.dim_model,
                parent.dim_keys,
                parent.dim_values,
                parent.attention_dropout
        )

        self.attn = mha.MultiHeadAttention(*attention_config)
        self.feed_forward = ffl.FeedForwardLayer(parent.dim_model)
        self.norm = normalization.Normalization()
        self.dropout = nn.Dropout(parent.residual_dropout)

    #  METHODS  ########################################################################################################

    def forward(self, sequence: torch.FloatTensor, padding_mask: torch.ByteTensor) -> torch.FloatTensor:
        """Runs the layer.

        Args:
            sequence (torch.FloatTensor): The input sequence as (batch_size x seq_len x dim_model)-tensor.
            padding_mask (torch.ByteTensor): The padding mask as (batch_size x seq_len x seq_len)-tensor or ``None`` if
                no mask is used.

        Returns:
            torch.FloatTensor: The encoded sequence as (batch_size x seq_len x dim_model)-tensor.
        """
        # sub-layer 1: attention of the sequence over itself
        attended = self.attn(sequence, sequence, sequence, mask=padding_mask)
        sequence = self.norm(self.dropout(attended) + sequence)

        # sub-layer 2: position-wise feed-forward network
        transformed = self.feed_forward(sequence)
        return self.norm(self.dropout(transformed) + sequence)

    def reset_parameters(self) -> None:
        """Resets the trainable parameters of the attention mechanism and of the feed-forward layer."""
        for module in (self.attn, self.feed_forward):
            module.reset_parameters()
class FeedForwardLayer(nn.Module):
    """A sublayer that computes a 1-hidden-layer multi-layer perceptron for each token in a sequence.

    Both linear maps are implemented as convolutions with kernel size 1, which apply the same weights to every token.
    """

    def __init__(self, dim_model: int, dim_hidden: int = None):
        """Creates a new instance of ``FeedForwardLayer``.

        Args:
            dim_model (int): The dimension of all tokens in the input sequence. This is called d_model, in the paper.
            dim_hidden (int, optional): The size of the hidden layer. If this is not provided, then ``dim_model`` is
                used for the hidden layer as well.

        Raises:
            TypeError: If ``dim_model`` or ``dim_hidden`` is not an ``int``.
            ValueError: If ``dim_model`` or ``dim_hidden`` is not a positive number.
        """
        super().__init__()

        # sanitize args
        if not isinstance(dim_model, int):
            raise TypeError("<dim_model> has to be an integer!")
        if dim_model < 1:
            raise ValueError("<dim_model> has to be a positive number!")
        if dim_hidden is None:
            dim_hidden = dim_model
        elif not isinstance(dim_hidden, int):
            raise TypeError("<dim_hidden> has to be an integer!")
        elif dim_hidden < 1:
            raise ValueError("<dim_hidden> has to be a positive number!")

        # store args
        self._dim_model = dim_model
        self._dim_hidden = dim_hidden

        # create layers
        self._layer_1 = nn.Conv1d(self._dim_model, self._dim_hidden, 1)
        self._layer_2 = nn.Conv1d(self._dim_hidden, self._dim_model, 1)

    #  PROPERTIES  #####################################################################################################

    @property
    def dim_hidden(self) -> int:
        """int: The size of the hidden layer."""
        return self._dim_hidden

    @property
    def dim_model(self) -> int:
        """int: The dimension of all tokens in the input sequence.

        This is called d_model, in the paper.
        """
        return self._dim_model

    @property
    def layer_1(self) -> nn.Conv1d:
        """nn.Conv1d: The first linear layer (before the ReLU non-linearity is applied)."""
        return self._layer_1

    @property
    def layer_2(self) -> nn.Conv1d:
        """nn.Conv1d: The second linear layer."""
        return self._layer_2

    #  METHODS  ########################################################################################################

    def forward(self, sequence: torch.FloatTensor) -> torch.FloatTensor:
        """Runs the feed-forward layer.

        Args:
            sequence (torch.FloatTensor): The input sequence given as (batch_size x seq_len x dim_model)-tensor.

        Returns:
            torch.FloatTensor: The computed values as (batch_size x seq_len x dim_model)-tensor.
        """
        assert sequence.dim() == 3
        assert sequence.size(2) == self._dim_model

        # the convolutions expect the channels, i.e., the token dimension, in axis 1 -> swap axes 1 and 2
        hidden = functional.relu(self._layer_1(sequence.transpose(1, 2)))

        # swap the axes back after the second layer, to restore the layout of the input
        return self._layer_2(hidden).transpose(1, 2)

    def reset_parameters(self) -> None:
        """Resets all trainable parameters of the module."""
        self._layer_1.reset_parameters()
        self._layer_2.reset_parameters()
class MultiHeadAttention(nn.Module):
    """A multi-head scaled dot-product attention mechanism as it is used in *Attention Is All You Need*."""

    def __init__(self, num_heads: int, dim_model: int, dim_keys: int, dim_values: int, dropout_rate: float):
        """Creates a new instance of ``MultiHeadAttention``.

        Notice:
            This constructor does not sanitize any parameters, which means that this has to be taken care of beforehand.

        Args:
            num_heads (int): The number of attention heads to use.
            dim_model (int): The dimension used for all layers in the model that the ``MultiHeadAttention`` belongs to.
            dim_keys (int): The target size to project keys to.
            dim_values (int): The target size to project values to.
            dropout_rate (float): The dropout probability to use.
        """
        super().__init__()

        # store all of the provided args
        self.dim_keys = dim_keys
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.dropout_rate = dropout_rate
        self.num_heads = num_heads

        # create projections for inputs
        self.query_projection = nn.Parameter(torch.empty(self.num_heads, self.dim_model, self.dim_keys))
        self.key_projection = nn.Parameter(torch.empty(self.num_heads, self.dim_model, self.dim_keys))
        self.value_projection = nn.Parameter(torch.empty(self.num_heads, self.dim_model, self.dim_values))

        # create output projection
        self.output_projection = nn.Parameter(torch.empty(self.num_heads * self.dim_values, self.dim_model))

        # create softmax and dropout layers
        self.dropout = nn.Dropout(self.dropout_rate)
        self.softmax = nn.Softmax(dim=3)

        # initialize all parameters
        self.reset_parameters()

    #  METHODS  ########################################################################################################

    def _apply_attention(
            self,
            queries: torch.FloatTensor,
            keys: torch.FloatTensor,
            values: torch.FloatTensor,
            mask: typing.Optional[torch.Tensor]
    ) -> torch.Tensor:
        """The actual attention mechanism.

        Args:
            queries (torch.FloatTensor): The queries as (batch_size x num_heads x Q x dim_keys)-tensor.
            keys (torch.FloatTensor): The keys as (batch_size x num_heads x KV x dim_keys)-tensor.
            values (torch.FloatTensor): The values as (batch_size x num_heads x KV x dim_values)-tensor.
            mask (torch.Tensor): An optional binary mask (byte or boolean) that indicates which key-value pairs to
                consider for each of the queries. If provided, then this has to be a (batch_size x Q x KV)-tensor.

        Returns:
            torch.FloatTensor: The computed "attended" values as (batch_size x num_heads x Q x dim_values)-tensor. If
                the ``mask`` specifies that none of the key-value pairs shall be used for any of the queries, then the
                according attended value is set to ``0``.
        """
        # compute inputs to the softmax
        attn = queries.matmul(keys.transpose(2, 3)) / np.sqrt(self.dim_keys)  # compute (Q * K^T) / sqrt(d_k)
        # -> (batch_size x num_heads x Q x KV)

        if mask is None:
            # compute attention scores
            attn = self.softmax(attn)
        else:
            # work on a boolean view of the mask, which is broadcast over all heads
            allowed = mask.to(torch.bool).unsqueeze(1)  # -> (batch_size x 1 x Q x KV)

            # check whether the mask excludes all of the entries
            if not allowed.any():
                # the result has to be shaped like the attended values (last dim = dim_values, NOT dim_keys), and it
                # has to live on the same device and use the same dtype as the values
                return values.new_zeros(queries.size(0), queries.size(1), queries.size(2), values.size(3))

            # determine which queries are allowed to attend to at least one key-value pair
            has_keys = allowed.any(dim=-1, keepdim=True)  # -> (batch_size x 1 x Q x 1)

            # queries without any admissible key are left unmasked for now, as a softmax over -inf only would yield NaN
            blocked = (~allowed) & has_keys

            # apply mask
            attn = attn.masked_fill(blocked, -np.inf)

            # compute attention scores
            attn = self.softmax(attn)

            # zero the scores of those queries that must not attend to anything
            attn = attn * has_keys.to(attn.dtype)

        # apply dropout
        attn = self.dropout(attn)

        # compute attended value
        return attn.matmul(values)  # -> (batch_size x num_heads x Q x dim_values)

    def _project_inputs(
            self,
            queries: torch.FloatTensor,
            keys: torch.FloatTensor,
            values: torch.FloatTensor
    ) -> typing.Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor
    ]:
        """Projects all inputs provided to the attention mechanism to the needed sizes.

        This means that queries and keys are projected from ``dim_model`` to ``dim_keys``, and values from ``dim_model``
        to ``dim_values``.

        Args:
            queries (torch.FloatTensor): The queries as (batch_size x Q x dim_model)-tensor.
            keys (torch.FloatTensor): The keys as (batch_size x KV x dim_model)-tensor.
            values (torch.FloatTensor): The values as (batch_size x KV x dim_model)-tensor.

        Returns:
            tuple: A triple of ``FloatTensor``s, consisting of the projected queries, keys, and values.
        """
        # for each of the attention heads, project inputs to the needed dimensions
        # (the inserted dimension is broadcast against the leading num_heads dimension of the projections)
        queries = queries.unsqueeze(1).matmul(self.query_projection)  # -> (batch_size x num_heads x Q   x dim_keys)
        keys = keys.unsqueeze(1).matmul(self.key_projection)          # -> (batch_size x num_heads x KV  x dim_keys)
        values = values.unsqueeze(1).matmul(self.value_projection)    # -> (batch_size x num_heads x KV  x dim_values)

        return queries, keys, values

    def _project_output(self, attn_values: torch.FloatTensor) -> torch.FloatTensor:
        """Projects the "attended" values of all heads to the required output size.

        Args:
            attn_values (torch.FloatTensor): The attended values as (batch_size x num_heads x Q x dim_values)-tensor.

        Returns:
            torch.FloatTensor: The computed output as (batch_size x Q x dim_model)-tensor.
        """
        # concatenate the values retrieved from all heads
        batch_size = attn_values.size(0)
        num_queries = attn_values.size(2)
        attn_values = attn_values.transpose(1, 2).reshape(batch_size, num_queries, -1)
        # -> (batch_size x Q x (num_heads * dim_values))

        return attn_values.matmul(self.output_projection)  # -> (batch-size x Q x dim_model)

    def forward(
            self,
            queries: torch.FloatTensor,
            keys: torch.FloatTensor,
            values: torch.FloatTensor,
            mask: typing.Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Runs the attention mechanism.

        Args:
            queries (torch.FloatTensor): The queries as (batch_size x Q x dim_model)-tensor.
            keys (torch.FloatTensor): The keys as (batch_size x KV x dim_model)-tensor.
            values (torch.FloatTensor): The values as (batch_size x KV x dim_model)-tensor.
            mask (torch.Tensor, optional): An optional binary mask (byte or boolean) that indicates which key-value
                pairs to consider for each of the queries. If provided, then this has to be a
                (batch_size x Q x KV)-tensor.

        Returns:
            torch.FloatTensor: The values computed by the attention mechanism as (batch_size x Q x dim_model)-tensor.
        """
        # NOTE: these are debugging checks only (they vanish under -O); args are expected to be sanitized by the caller
        # the dtype is checked rather than the legacy tensor classes, which makes the checks independent of the device
        assert isinstance(queries, torch.Tensor) and queries.dtype == torch.float32
        assert isinstance(keys, torch.Tensor) and keys.dtype == torch.float32
        assert isinstance(values, torch.Tensor) and values.dtype == torch.float32
        assert queries.dim() == 3
        assert keys.dim() == 3
        assert values.dim() == 3
        assert queries.size(0) == keys.size(0)
        assert queries.size(0) == values.size(0)
        assert queries.size(2) == keys.size(2)
        assert keys.size(1) == values.size(1)
        if mask is not None:
            assert isinstance(mask, torch.Tensor) and mask.dtype in (torch.uint8, torch.bool)
            assert mask.dim() == 3
            assert queries.size(0) == mask.size(0)
            assert queries.size(1) == mask.size(1)
            assert keys.size(1) == mask.size(2)

        # for each of the attention heads, project inputs to the needed dimensions
        queries, keys, values = self._project_inputs(queries, keys, values)

        # compute attention value
        attn_values = self._apply_attention(queries, keys, values, mask)

        # project retrieved values to needed dimensions
        return self._project_output(attn_values)

    def reset_parameters(self):
        """Resets all trainable parameters of the module."""
        init.xavier_normal_(self.query_projection)
        init.xavier_normal_(self.key_projection)
        init.xavier_normal_(self.value_projection)
        init.xavier_normal_(self.output_projection)
class Normalization(nn.Module):
    """A normalization layer that standardizes its input along the last dimension.

    Each vector along the last dimension is shifted to zero mean and divided by its standard deviation (as computed by
    ``torch.std``) plus ``eps``. The layer does not have any trainable parameters.
    """

    def __init__(self, eps: numbers.Real = 1e-15):
        """Creates a new instance of ``Normalization``.

        Args:
            eps (numbers.Real, optional): A tiny number to be added to the standard deviation before re-scaling the
                centered values. This prevents divide-by-0 errors. By default, this is set to ``1e-15``.

        Raises:
            TypeError: If ``eps`` is not a real number.
        """
        super().__init__()

        self._eps = None
        # assign via the property, such that the type check of the setter is applied to the constructor arg as well
        self.eps = eps

    #  PROPERTIES  #####################################################################################################

    @property
    def eps(self) -> float:
        """float: A tiny number that is added to the standard deviation before re-scaling the centered values.

        This prevents divide-by-0 errors. By default, this is set to ``1e-15``.
        """
        return self._eps

    @eps.setter
    def eps(self, eps: numbers.Real) -> None:
        if not isinstance(eps, numbers.Real):
            raise TypeError("<eps> has to be a real number!")
        self._eps = float(eps)

    #  METHODS  ########################################################################################################

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Runs the normalization layer.

        Args:
            x (torch.FloatTensor): A tensor to be normalized. To that end, ``x`` is interpreted as a batch of values
                where normalization is applied over the last of its dimensions.

        Returns:
            torch.FloatTensor: The normalized tensor, which has the same shape as ``x``.
        """
        mean = torch.mean(x, dim=-1, keepdim=True)

        if x.size(-1) > 1:
            std = torch.std(x, dim=-1, keepdim=True)
        else:
            # the sample standard deviation of a single value is undefined (NaN), which would poison the entire output
            # -> treat it as 0, such that the (already centered) value is mapped to 0
            std = torch.zeros_like(mean)

        return (x - mean) / (std + self._eps)
class Transformer(nn.Module):
    """The Transformer model that was introduced in *Attention Is All You Need*."""

    def __init__(
            self,
            word_emb: nn.Embedding,
            pad_index: int,
            output_size: int,
            positional_emb: nn.Embedding = None,
            max_seq_len: int = None,
            num_layers: int = 6,
            num_heads: int = 8,
            dim_model: int = 512,
            dim_keys: int = 64,
            dim_values: int = 64,
            residual_dropout: numbers.Real = 0.1,
            attention_dropout: numbers.Real = 0.1
    ):
        """Creates a new instance of ``Transformer``.

        Args:
            word_emb (nn.Embedding): The word embeddings to use.
            pad_index (int): The index that indicates that a token in an input sequence is just padding.
            output_size (int): The size, i.e., the number of dimensions, of the output to provide.
            positional_emb (nn.Embedding, optional): The positional embeddings to use.
            max_seq_len (int, optional): The maximum length of any input or output sequences. This is used to generate
                positional embeddings, if ``positional_emb`` is not provided.
            num_layers (int): The number of layers to use in both the encoder and the decoder.
            num_heads (int): The number of attention heads to use.
            dim_model (int): The dimension to use for all layers. This is called d_model, in the paper.
            dim_keys (int): The size of the keys provided to the attention mechanism. This is called d_k, in the paper.
            dim_values (int): The size of the values provided to the attention mechanism. This is called d_v, in the
                paper.
            residual_dropout (numbers.Real): The dropout probability for residual connections (before they are added to
                the sublayer output).
            attention_dropout (numbers.Real): The dropout probability for values provided by the attention mechanism.

        Raises:
            TypeError: If any of the checked args is of the wrong type.
            ValueError: If any of the checked args has an illegal value, or if neither ``positional_emb`` nor
                ``max_seq_len`` has been provided.
        """
        super().__init__()

        # sanitize args
        if not isinstance(word_emb, nn.Embedding):
            raise TypeError("<word_emb> has to be an instance of torch.nn.Embedding!")
        if not isinstance(output_size, int):
            raise TypeError("The <output_size> has to be an integer!")
        if output_size < 1:
            raise ValueError("The <output_size> has to be a positive number!")
        if positional_emb is not None:
            if not isinstance(positional_emb, nn.Embedding):
                raise TypeError("<positional_emb> has to be an instance of torch.nn.Embedding!")
            if word_emb.embedding_dim != positional_emb.embedding_dim:
                raise ValueError("<word_emb> and <positional_emb> have to use the same embedding size!")
        if max_seq_len is not None:
            if not isinstance(max_seq_len, int):
                raise TypeError("The <max_seq_len> has to be an integer!")
            if max_seq_len < 1:
                raise ValueError("<max_seq_len> has to be a positive number!")
            elif positional_emb is not None and max_seq_len > positional_emb.num_embeddings:
                raise ValueError("<max_seq_len> cannot be greater than the number of embeddings in <positional_emb>!")
        elif positional_emb is None:
            raise ValueError("At least one of the args <positional_emb> and <max_seq_len> has to be provided!")

        # store output_size and pad_index
        self._output_size = output_size
        self._pad_index = pad_index
        self._word_emb = word_emb

        # create encoder and decoder
        # (these are created first, because they sanitize all of the provided args)
        self._encoder = enc.Encoder(
                num_layers,
                num_heads,
                dim_model,
                dim_keys,
                dim_values,
                residual_dropout,
                attention_dropout,
                pad_index
        )
        self._decoder = dec.Decoder(
                num_layers,
                num_heads,
                dim_model,
                dim_keys,
                dim_values,
                residual_dropout,
                attention_dropout,
                pad_index
        )

        # store embeddings
        if positional_emb is None:
            self._positional_emb = util.create_positional_emb(max_seq_len, word_emb.embedding_dim, dim_model)
        else:
            self._positional_emb = positional_emb

        # figure out the maximum sequence length
        # (this is determined by the positional embeddings, even if a smaller <max_seq_len> was provided with them)
        self._max_seq_len = self._positional_emb.num_embeddings

        # create linear projections for input (word embeddings) and output
        self._input_projection = nn.Linear(self._word_emb.embedding_dim, dim_model)
        self._output_projection = nn.Linear(dim_model, self._output_size)

    #  PROPERTIES  #####################################################################################################

    @property
    def decoder(self) -> dec.Decoder:
        """:class:`dec.Decoder`: The decoder part of the Transformer."""
        return self._decoder

    @property
    def embedding_dim(self) -> int:
        """int: The used embedding size."""
        return self._word_emb.embedding_dim

    @property
    def encoder(self) -> enc.Encoder:
        """:class:`enc.Encoder`: The encoder part of the Transformer."""
        return self._encoder

    @property
    def input_projection(self) -> nn.Linear:
        """nn.Linear: The linear projection between input and encoder."""
        return self._input_projection

    @property
    def max_seq_len(self) -> int:
        """int: The maximum length that any input sequence may have."""
        return self._max_seq_len

    @property
    def output_projection(self) -> nn.Linear:
        """nn.Linear: The linear projection between decoder and output."""
        return self._output_projection

    @property
    def output_size(self) -> int:
        """int: The size of the output provided by the ``Transformer``."""
        return self._output_size

    @property
    def pad_index(self) -> int:
        """int: The index that indicates that a token in an input sequence is just padding."""
        return self._pad_index

    @property
    def positional_emb(self) -> nn.Embedding:
        """nn.Embedding: The used positional embeddings."""
        return self._positional_emb

    @property
    def word_emb(self) -> nn.Embedding:
        """nn.Embedding: The used word embeddings."""
        return self._word_emb

    #  METHODS  ########################################################################################################

    def forward(self, input_seq: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
        """Runs the Transformer.

        The Transformer expects both an input as well as a target sequence to be provided, and yields a probability
        distribution over all possible output tokens for each position in the target sequence.

        Args:
            input_seq (torch.LongTensor): The input sequence as (batch-size x input-seq-len)-tensor.
            target (torch.LongTensor): The target sequence as (batch-size x target-seq-len)-tensor.

        Returns:
            torch.FloatTensor: The computed probabilities for each position in ``target`` as a
                (batch-size x target-seq-len x output-size)-tensor.

        Raises:
            TypeError: If ``input_seq`` or ``target`` is not a tensor of type long.
            ValueError: If ``input_seq`` or ``target`` is not 2-dimensional, or longer than :attr:`max_seq_len`.
        """
        # sanitize args
        # (the dtype is checked rather than the legacy tensor classes, which makes the checks device-independent)
        if not isinstance(input_seq, torch.Tensor) or input_seq.dtype != torch.long:
            raise TypeError("<input_seq> has to be a LongTensor!")
        if input_seq.dim() != 2:
            raise ValueError("<input_seq> has to have 2 dimensions!")
        if input_seq.size(1) > self._max_seq_len:
            raise ValueError("<input_seq> is longer than the maximum sequence length!")
        if not isinstance(target, torch.Tensor) or target.dtype != torch.long:
            raise TypeError("<target> has to be a LongTensor!")
        if target.dim() != 2:
            raise ValueError("<target> has to have 2 dimensions!")
        if target.size(1) > self._max_seq_len:
            raise ValueError("<target> is longer than the maximum sequence length!")

        # create a tensor of indices, which is used to retrieve the according positional embeddings below
        index_seq = torch.arange(input_seq.size(1), dtype=torch.long, device=input_seq.device)
        index_seq = index_seq.unsqueeze(0).expand(input_seq.size(0), -1)

        # create padding mask for input
        padding_mask = util.create_padding_mask(input_seq, self._pad_index)

        # embed the provided input
        input_seq = self._word_emb(input_seq) + self._positional_emb(index_seq)

        # project input to the needed size
        input_seq = self._input_projection(input_seq)

        # run the encoder
        input_seq = self._encoder(input_seq, padding_mask=padding_mask)

        # create a tensor of indices, which is used to retrieve the positional embeddings for the targets below
        index_seq = torch.arange(target.size(1), dtype=torch.long, device=target.device)
        index_seq = index_seq.unsqueeze(0).expand(target.size(0), -1)

        # embed the provided targets
        target = self._word_emb(target) + self._positional_emb(index_seq)

        # project target to the needed size
        # (inputs and targets share one and the same projection)
        target = self._input_projection(target)

        # run the decoder
        # (the padding mask was computed from the INPUT sequence, and is handed to the decoder together with the
        # encoded input)
        output = self._decoder(input_seq, target, padding_mask=padding_mask)

        # project output to the needed size
        output = self._output_projection(output)

        # compute softmax
        return functional.softmax(output, dim=2)

    def reset_parameters(self) -> None:
        """Resets all trainable parameters of the module.

        Notice that this affects the encoder, the decoder, and the input and output projections only. The word and
        positional embeddings are not touched.
        """
        self._encoder.reset_parameters()
        self._decoder.reset_parameters()
        self._input_projection.reset_parameters()
        self._output_projection.reset_parameters()
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 31 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 32 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 33 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 34 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 35 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 36 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 37 | ) 38 | __license__ = "BSD-2-Clause" 39 | __version__ = "2018.1" 40 | __date__ = "Aug 30, 2018" 41 | __maintainer__ = "Patrick Hohenecker" 42 | __email__ = "mail@paho.at" 43 | __status__ = "Development" 44 | 45 | 46 | def eval_probability( 47 | model: transformer.Transformer, 48 | input_seq: torch.LongTensor, 49 | target_seq: torch.LongTensor, 50 | pad_index: int=None 51 | ) -> torch.FloatTensor: 52 | """Computes the probability that the provided model computes a target sequence given an input sequence. 53 | 54 | Args: 55 | model (:class:`transformer.Transformer`): The model to use. 56 | input_seq (torch.LongTensor): The input sequence to be provided to the model. This has to be a 57 | (batch-size x input-seq-len)-tensor. 58 | target_seq (torch.LongTensor): The target sequence whose probability is being evaluated. This has to be a 59 | (batch-size x target-seq-len)-tensor. 60 | pad_index (int, optional): The index that indicates a padding token in a sequence. If ``target_seq`` is padded, 61 | then the ``pad_index`` has to be provided in order to allow for computing the probabilities for relevant 62 | parts of the target sequence only. 63 | 64 | Returns: 65 | torch.FloatTensor: A 1D-tensor of size (batch-size), which contains one probability for each sample in 66 | ``input_seq`` and ``target_seq``, respectively. 
67 | """ 68 | if not isinstance(model, transformer.Transformer): 69 | raise TypeError("The has to be a transformer.Transformer!") 70 | if not isinstance(input_seq, torch.LongTensor) and not isinstance(input_seq, torch.cuda.LongTensor): 71 | raise TypeError("The has to be a LongTensor!") 72 | if input_seq.dim() != 2: 73 | raise ValueError(" has to be a 2D-tensor!") 74 | if input_seq.is_cuda: 75 | if not isinstance(target_seq, torch.cuda.LongTensor): 76 | raise TypeError("The has to be of the same type as , i.e., cuda.LongTensor!") 77 | elif not isinstance(target_seq, torch.LongTensor): 78 | raise TypeError("The has to be of the same type as , i.e., LongTensor!") 79 | if target_seq.dim() != 2: 80 | raise ValueError(" has to be a 2D-tensor!") 81 | if input_seq.size(0) != target_seq.size(0): 82 | raise ValueError(" and use different batch sizes!") 83 | if pad_index is not None and not isinstance(pad_index, int): 84 | raise TypeError("The , if provided, has to be an integer!") 85 | 86 | batch_size = input_seq.size(0) 87 | max_seq_len = input_seq.size(1) 88 | 89 | # put model in evaluation mode 90 | original_mode = model.training # store original mode (train/eval) to be restored eventually 91 | model.eval() 92 | 93 | # run the model to compute the needed probabilities 94 | predictions = model(input_seq, target_seq) 95 | 96 | # determine the lengths of the target sequences 97 | if pad_index is not None: 98 | mask = util.create_padding_mask(target_seq, pad_index)[:, 0, :] 99 | seq_len = mask.sum(dim=1).cpu().numpy().tolist() 100 | else: 101 | seq_len = (np.ones(batch_size, dtype=np.long) * max_seq_len).tolist() 102 | 103 | # compute the probabilities for each of the provided samples 104 | sample_probs = torch.ones(batch_size) 105 | for sample_idx in range(batch_size): # iterate over each sample 106 | for token_idx in range(seq_len[sample_idx]): # iterate over each position in the output sequence 107 | sample_probs[sample_idx] *= predictions[sample_idx, token_idx, 
def sample_output(
        model: transformer.Transformer,
        input_seq: torch.LongTensor,
        eos_index: int,
        pad_index: int,
        max_len: int
) -> torch.LongTensor:
    """Samples an output sequence based on the provided input.

    The output is generated token by token: in every step, the model is run on the input together with everything
    that has been generated so far, and the next token is drawn from the distribution that the model provides for the
    last position. Once a sample has produced ``eos_index``, all of its subsequent positions are filled with
    ``pad_index``. Generation stops as soon as every sample has produced ``eos_index`` or ``max_len`` tokens have been
    generated, whichever happens first.

    The model is put into evaluation mode for the duration of the call, and its original mode (train/eval) is
    restored afterwards, even if sampling fails with an exception. No gradients are tracked while sampling.

    Args:
        model (:class:`transformer.Transformer`): The model to use.
        input_seq (torch.LongTensor): The input sequence to be provided to the model. This has to be a
            (batch-size x input-seq-len)-tensor.
        eos_index (int): The index that indicates the end of a sequence.
        pad_index (int): The index that indicates a padding token in a sequence.
        max_len (int): The maximum length of the generated output.

    Returns:
        torch.LongTensor: The generated output sequence as (batch-size x output-seq-len)-tensor.

    Raises:
        TypeError: If any of the args is of an unexpected type.
        ValueError: If ``input_seq`` is not 2-dimensional, if ``eos_index`` or ``pad_index`` is outside of
            ``[0, model.output_size)``, or if ``max_len`` is less than 1.
    """
    # sanitize args
    if not isinstance(model, transformer.Transformer):
        raise TypeError("The <model> has to be a transformer.Transformer!")
    if not isinstance(input_seq, torch.LongTensor) and not isinstance(input_seq, torch.cuda.LongTensor):
        raise TypeError("The <input_seq> has to be a LongTensor!")
    if input_seq.dim() != 2:
        raise ValueError("<input_seq> has to be a matrix!")
    if not isinstance(eos_index, int):
        raise TypeError("The <eos_index> has to be an integer!")
    if eos_index < 0 or eos_index >= model.output_size:
        raise ValueError("The <eos_index> is not a legal index in the vocabulary used by <model>!")
    if not isinstance(pad_index, int):
        raise TypeError("The <pad_index> has to be an integer!")
    if pad_index < 0 or pad_index >= model.output_size:
        raise ValueError("The <pad_index> is not a legal index in the vocabulary used by <model>!")
    # NOTE: None used to slip through this check, only to fail later on with an obscure TypeError raised by range()
    if not isinstance(max_len, int):
        raise TypeError("<max_len> has to be an integer!")
    if max_len < 1:
        raise ValueError("<max_len> has to be > 0!")

    original_mode = model.training  # the original mode (train/eval) of the provided model
    batch_size = input_seq.size(0)  # number of samples in the provided input sequence

    output_seq = []  # used to store the generated outputs for each position
    finished = [False] * batch_size  # indicates which samples have produced the EOS token already

    # put model in evaluation mode
    model.eval()

    try:
        with torch.no_grad():  # sampled indices are not differentiable -> no need to record a graph
            for _ in range(max_len):

                # prepare the target to provide to the model
                # this is the current output with an additional final entry that is supposed to be predicted next
                # (which is why the concrete value does not matter)
                current_target = torch.cat(output_seq + [input_seq.new(batch_size, 1).zero_()], dim=1)

                # run the model, and keep the values computed for the last position only
                probs = model(input_seq, current_target)[:, -1, :]

                # sample next output from the computed probabilities
                output = torch.multinomial(probs, 1)

                # determine which samples have been finished, and replace sampled output with padding for those
                # that are already (the sampled tokens are fetched with a single transfer rather than one per sample)
                for sample_idx, token in enumerate(output[:, 0].tolist()):
                    if finished[sample_idx]:
                        output[sample_idx, 0] = pad_index
                    elif token == eos_index:
                        finished[sample_idx] = True

                # store created output
                output_seq.append(output)

                # check whether generation has been finished
                if all(finished):
                    break
    finally:
        # restore original mode of the model, no matter whether sampling succeeded
        model.train(mode=original_mode)

    return torch.cat(output_seq, dim=1)
def create_padding_mask(seq: torch.LongTensor, pad_index: int) -> torch.Tensor:
    """Creates a mask for the provided sequence that indicates which of the tokens are actual data and which are just
    padding.

    Args:
        seq (torch.LongTensor): The input sequences that the padding mask is created for. ``seq`` has to be a
            tensor of dtype ``torch.long`` (on any device) and shape (batch-size x seq-len).
        pad_index (int): The index that indicates a padding token.

    Returns:
        torch.Tensor: A binary mask where ``1``s represent tokens that belong to the actual sequences and ``0``s
        indicate padding. The provided mask has the shape (batch-size x seq-len x seq-len), i.e., the mask of each
        sample is repeated once for every position of the sequence. The mask is the result of an element-wise
        comparison, which means that its dtype is whatever the installed version of PyTorch uses for comparisons.

    Raises:
        TypeError: If ``seq`` is not a tensor of dtype ``torch.long`` or ``pad_index`` is not an ``int``.
        ValueError: If ``seq`` is not 2-dimensional.
    """
    # sanitize args
    # (the dtype is checked rather than the legacy tensor classes, which are tied to CPU/CUDA)
    if not isinstance(seq, torch.Tensor) or seq.dtype != torch.long:
        raise TypeError("<seq> has to be a LongTensor!")
    if seq.dim() != 2:
        raise ValueError("<seq> has to be a 2-dimensional tensor!")
    if not isinstance(pad_index, int):
        raise TypeError("<pad_index> has to be an int!")

    seq_len = seq.size(1)

    # the inserted dimension is expanded (not copied), such that every query position shares the same key mask
    return (seq != pad_index).unsqueeze(1).expand(-1, seq_len, -1)


def create_positional_emb(max_seq_len: int, embedding_size: int, dim_model: int) -> nn.Embedding:
    """Creates positional embeddings.

    The embedding of position ``pos`` consists of pairs ``sin(pos / 10000^(i / dim_model))`` and
    ``cos(pos / 10000^(i / dim_model))`` for ``i = 0, 2, 4, ...``. If ``embedding_size`` is odd, then the cosine of
    the last pair is dropped, such that every embedding vector has exactly ``embedding_size`` entries.

    Args:
        max_seq_len (int): The maximum length of any input sequence, which corresponds with the total number of
            embedding vectors needed.
        embedding_size (int): The size of the embeddings to create.
        dim_model (int): The default layer size used in the model.

    Returns:
        nn.Embedding: The created positional embeddings, which contain a (max-seq-len x embedding-size) weight
        matrix.
    """
    positions = np.arange(max_seq_len, dtype=np.float32)

    # collect the columns of the embedding matrix as an actual list, since NumPy does not stack plain iterators
    columns = []
    for i in range(0, embedding_size, 2):
        angles = positions / (10000 ** (i / dim_model))
        columns.append(np.sin(angles))
        columns.append(np.cos(angles))
    emb_matrix = np.stack(columns, axis=1)

    # if embedding_size is an odd number, then one column too many has been created, which has to be removed again
    if emb_matrix.shape[1] > embedding_size:
        emb_matrix = np.ascontiguousarray(emb_matrix[:, :embedding_size])

    return nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix))


def create_shifted_output_mask(seq: torch.Tensor) -> torch.ByteTensor:
    """Creates a mask that prevents the decoder to attend future outputs.

    For each sample in the provided batch, the created mask is a square matrix that contains one row for every
    position in the output sequence. Each of these rows indicates those parts of the sequence that may be considered in
    order to compute the respective output, i.e., the respective position itself as well as all earlier ones.

    Args:
        seq (torch.Tensor): The output sequence that the mask is created for. ``seq`` has to be a tensor of
            shape (batch-size x seq-len x ...), i.e., it has to have at least two dimensions. Only its shape and
            device are used.

    Returns:
        torch.ByteTensor: A binary mask where ``1``s represent tokens that should be considered for the respective
        position and ``0``s indicate future outputs. The provided mask has shape (batch-size x seq-len x seq-len),
        and lives on the same device as ``seq``.

    Raises:
        TypeError: If ``seq`` is not a tensor.
        ValueError: If ``seq`` has less than two dimensions.
    """
    # sanitize args
    if not isinstance(seq, torch.Tensor):
        raise TypeError("<seq> has to be a Tensor!")
    if seq.dim() < 2:
        raise ValueError("<seq> has to be at least a 2-dimensional tensor!")

    batch_size = seq.size(0)
    seq_len = seq.size(1)

    # create a mask for one sample
    # (a lower-triangular matrix of ones, created independently of the dtype of seq)
    mask = torch.ones(seq_len, seq_len, dtype=torch.uint8, device=seq.device).tril()

    # share the mask among all samples in the batch
    return mask.unsqueeze(0).expand(batch_size, -1, -1)


def shift_output_sequence(seq: torch.Tensor, zero_range: numbers.Real=1e-22) -> torch.Tensor:
    """Shifts the provided output sequence one position to the right.

    To shift the sequence, this function truncates the last element of ``seq`` and prepends a zero-entry to every
    sample of the provided batch. However, to prevent ``nan`` values in the gradients of tensors created by means of
    ``torch.std``, the prepended tensors are not actually set to 0, but sampled uniformly from a tiny interval around
    0, which may be adjusted via the arg ``zero_range``.

    Args:
        seq (torch.Tensor): The sequence to shift as (batch-size x seq-length x dim-model)-tensor.
        zero_range (numbers.Real, optional): Specifies the range to sample zero-entries from as interval
            [``-zero_range``, ``zero_range``].

    Returns:
        torch.Tensor: The shifted sequence, which, just like ``seq``, is a (batch-size x seq-length x dim-model)-tensor.

    Raises:
        TypeError: If ``seq`` is not a tensor or ``zero_range`` is not a real number.
        ValueError: If ``seq`` is not 3-dimensional or ``zero_range`` is not a positive number.
    """
    # sanitize args
    if not isinstance(seq, torch.Tensor):
        raise TypeError("<seq> has to be a tensor!")
    if seq.dim() != 3:
        raise ValueError("Expected <seq> to be 3D, but {} dimensions were encountered!".format(seq.dim()))
    if not isinstance(zero_range, numbers.Real):
        raise TypeError("The <zero_range> has to be a real number!")
    zero_range = float(zero_range)
    if not zero_range > 0:  # phrased this way to reject NaN as well
        raise ValueError("The <zero_range> has to be a positive number!")

    # the (almost) zero entry that is prepended to every sample, created with the same type/device as seq
    first_entry = seq.new(seq.size(0), 1, seq.size(2)).uniform_(-zero_range, zero_range)

    return torch.cat([first_entry, seq[:, :-1, :]], dim=1)
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 24 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 25 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 26 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 27 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 28 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 29 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 30 | ) 31 | __license__ = "BSD-2-Clause" 32 | __version__ = "2018.1" 33 | __date__ = "Aug 21, 2018" 34 | __maintainer__ = "Patrick Hohenecker" 35 | __email__ = "mail@paho.at" 36 | __status__ = "Development" 37 | -------------------------------------------------------------------------------- /src/test/python/transformer_test/decoder_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import torch 5 | import torchtestcase as ttc 6 | 7 | from transformer import decoder 8 | from transformer import util 9 | 10 | 11 | __author__ = "Patrick Hohenecker" 12 | __copyright__ = ( 13 | "Copyright (c) 2018, Patrick Hohenecker\n" 14 | "All rights reserved.\n" 15 | "\n" 16 | "Redistribution and use in source and binary forms, with or without\n" 17 | "modification, are permitted provided that the following conditions are met:\n" 18 | "\n" 19 | "1. Redistributions of source code must retain the above copyright notice, this\n" 20 | " list of conditions and the following disclaimer.\n" 21 | "2. 
Redistributions in binary form must reproduce the above copyright notice,\n" 22 | " this list of conditions and the following disclaimer in the documentation\n" 23 | " and/or other materials provided with the distribution.\n" 24 | "\n" 25 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 26 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 27 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 28 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 29 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 30 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 31 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 32 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 33 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 34 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 
class DecoderTest(ttc.TorchTestCase):
    """Unit tests for :class:`decoder.Decoder`."""

    # float: The tolerance used for tensor equality assertions.
    TOLERANCE = 1e-4

    def setUp(self):
        # the configuration shared by all decoders that are created in the tests
        self.num_layers = 2
        self.num_heads = 4
        self.dim_model = 5
        self.dim_keys = 3
        self.dim_values = 3
        self.residual_dropout = 0.1
        self.attention_dropout = 0.2
        self.pad_index = 0

    def _create_decoder(self, residual_dropout, attention_dropout):
        """Creates a decoder that uses the test configuration together with the provided dropout rates."""
        return decoder.Decoder(
                self.num_layers,
                self.num_heads,
                self.dim_model,
                self.dim_keys,
                self.dim_values,
                residual_dropout,
                attention_dropout,
                self.pad_index
        )

    def _check_attention(self, attn):
        """Asserts that the provided attention mechanism was configured according to the test configuration."""
        self.assertEqual(self.dim_keys, attn.dim_keys)
        self.assertEqual(self.dim_model, attn.dim_model)
        self.assertEqual(self.dim_values, attn.dim_values)
        self.assertEqual(self.attention_dropout, attn.dropout_rate)
        self.assertEqual(self.num_heads, attn.num_heads)

    def test_forward(self):
        # dropout is turned off entirely, as the computed values could not be compared otherwise
        dec = self._create_decoder(0, 0)

        # create test data
        enc_input = torch.FloatTensor(
                [
                        [
                                [1, 2, 3, 4, 5],
                                [11, 22, 33, 44, 55],
                                [111, 222, 333, 444, 555]
                        ],
                        [
                                [6, 7, 8, 9, 0],
                                [66, 77, 88, 99, 0],
                                [666, 777, 888, 999, 0]
                        ]
                ]
        )
        pad_mask = torch.ones(2, 3, 3).byte()
        first_output = torch.FloatTensor(
                [
                        [
                                [0.1, 0.2, 0.3, 0.4, 0.5],
                                [0.11, 0.22, 0.33, 0.44, 0.55],
                                [0.111, 0.222, 0.333, 0.444, 0.555],
                                [0.1111, 0.2222, 0.3333, 0.4444, 0.5555]
                        ],
                        [
                                [0.6, 0.7, 0.8, 0.9, 0.0],
                                [0.66, 0.77, 0.88, 0.99, 0.00],
                                [0.666, 0.777, 0.888, 0.999, 0.000],
                                [0.6666, 0.7777, 0.8888, 0.9999, 0.0000]
                        ]
                ]
        )

        # NOTICE:
        # first_output and second_output differ at the last time step only
        second_output = first_output.clone()
        second_output[0, -1] = torch.FloatTensor([0.6, 0.7, 0.8, 0.9, 1.0])
        second_output[1, -1] = torch.FloatTensor([0.66, 0.77, 0.88, 0.99, 11.0])

        future_mask = util.create_shifted_output_mask(first_output)

        # compute target by running the layers of the decoder manually on the shifted output sequence
        expected = util.shift_output_sequence(first_output)
        for current_layer in dec._layers:
            expected = current_layer(enc_input, expected, pad_mask, future_mask)

        # run the decoder on both output sequences
        first_result = dec(enc_input, first_output, pad_mask)
        second_result = dec(enc_input, second_output, pad_mask)

        # CHECK: the provided output has the same shape as the target output sequence
        self.assertEqual(first_output.size(), first_result.size())

        # CHECK: the decoder computes the expected targets
        # -> both of the computed values have to be equal, as the provided target output sequences differ at the last
        #    time step only, which is not considered because of the performed shift
        self.eps = self.TOLERANCE
        self.assertEqual(expected, first_result)
        self.assertEqual(expected, second_result)

    def test_init(self):
        dec = self._create_decoder(self.residual_dropout, self.attention_dropout)

        # CHECK: the correct number of layers was created
        self.assertEqual(self.num_layers, len(dec._layers))

        for current_layer in dec._layers:  # iterate over all layers in the decoder

            # CHECK: the first attention mechanism was configured correctly
            self._check_attention(current_layer.attn_1)

            # CHECK: the second attention mechanism was configured correctly
            self._check_attention(current_layer.attn_2)

            # CHECK: the feed-forward layer was configured correctly
            self.assertEqual(self.dim_model, current_layer.feed_forward.dim_model)

            # CHECK: the dropout mechanism uses the correct dropout rate
            self.assertEqual(self.residual_dropout, current_layer.dropout.p)
class EncoderTest(ttc.TorchTestCase):
    """Unit tests for :class:`encoder.Encoder`."""

    # float: The maximum absolute difference tolerated when comparing computed tensors.
    TOLERANCE = 1e-5

    def setUp(self):
        # the configuration shared by all encoders that are created in the tests
        self.num_layers = 2
        self.num_heads = 4
        self.dim_model = 5
        self.dim_keys = 3
        self.dim_values = 3
        self.residual_dropout = 0.1
        self.attention_dropout = 0.2
        self.pad_index = 0

    def _create_encoder(self, residual_dropout, attention_dropout):
        """Creates an encoder that uses the test configuration together with the provided dropout rates."""
        return encoder.Encoder(
                self.num_layers,
                self.num_heads,
                self.dim_model,
                self.dim_keys,
                self.dim_values,
                residual_dropout,
                attention_dropout,
                self.pad_index
        )

    def test_forward(self):
        # dropout is turned off entirely, as the computed values could not be compared otherwise
        enc = self._create_encoder(0, 0)

        # create test data
        data = torch.FloatTensor(
                [
                        [
                                [1, 2, 3, 4, 5],
                                [11, 22, 33, 44, 55],
                                [111, 222, 333, 444, 555]
                        ],
                        [
                                [6, 7, 8, 9, 0],
                                [66, 77, 88, 99, 0],
                                [666, 777, 888, 999, 0]
                        ]
                ]
        )
        pad_mask = torch.ones(2, 3, 3).byte()

        # compute target by running the layers of the encoder manually, one after the other
        expected = data
        for current_layer in enc._layers:
            expected = current_layer(expected, pad_mask)
        expected = expected.detach()

        # run the encoder
        actual = enc(data).detach()

        # CHECK: the provided output has the same shape as the input
        self.assertEqual(data.size(), actual.size())

        # CHECK: the encoder computes the expected target
        deviation = (expected - actual).abs()
        max_deviation = torch.ones(expected.size()) * self.TOLERANCE
        self.assertLessEqual(deviation, max_deviation)

    def test_init(self):
        enc = self._create_encoder(self.residual_dropout, self.attention_dropout)

        # CHECK: the correct number of layers was created
        self.assertEqual(self.num_layers, len(enc._layers))

        for current_layer in enc._layers:  # iterate over all layers in the encoder
            attn = current_layer.attn

            # CHECK: the attention mechanism was configured correctly
            self.assertEqual(self.dim_keys, attn.dim_keys)
            self.assertEqual(self.dim_model, attn.dim_model)
            self.assertEqual(self.dim_values, attn.dim_values)
            self.assertEqual(self.attention_dropout, attn.dropout_rate)
            self.assertEqual(self.num_heads, attn.num_heads)

            # CHECK: the feed-forward layer was configured correctly
            self.assertEqual(self.dim_model, current_layer.feed_forward.dim_model)

            # CHECK: the dropout mechanism uses the correct dropout rate
            self.assertEqual(self.residual_dropout, current_layer.dropout.p)
provided that the following conditions are met:\n" 19 | "\n" 20 | "1. Redistributions of source code must retain the above copyright notice, this\n" 21 | " list of conditions and the following disclaimer.\n" 22 | "2. Redistributions in binary form must reproduce the above copyright notice,\n" 23 | " this list of conditions and the following disclaimer in the documentation\n" 24 | " and/or other materials provided with the distribution.\n" 25 | "\n" 26 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 27 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 28 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 29 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 30 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 31 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 32 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 33 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 34 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 35 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 
class FeedForwardLayerTest(ttc.TorchTestCase):
    """Unit tests for :class:`feed_forward_layer.FeedForwardLayer`."""

    # float: The maximum absolute difference tolerated when comparing computed tensors.
    TOLERANCE = 1e-5

    def test_forward(self):
        dim_model = 5

        # create model for testing
        layer = ffl.FeedForwardLayer(dim_model)

        # create test data
        data = torch.FloatTensor(
                [
                        [
                                [1, 2, 3, 4, 5],
                                [6, 7, 8, 9, 0]
                        ],
                        [
                                [11, 22, 33, 44, 55],
                                [66, 77, 88, 99, 0]
                        ]
                ]
        )

        # fetch parameters of the feed-forward layer
        w_1 = layer.layer_1.weight.data
        b_1 = layer.layer_1.bias.data
        w_2 = layer.layer_2.weight.data
        b_2 = layer.layer_2.bias.data

        # CHECK: the parameters are as expected
        expected_weight_shape = (dim_model, dim_model, 1)
        expected_bias_shape = (dim_model,)
        self.assertEqual(expected_weight_shape, w_1.size())
        self.assertEqual(expected_bias_shape, b_1.size())
        self.assertEqual(expected_weight_shape, w_2.size())
        self.assertEqual(expected_bias_shape, b_2.size())

        # turn all parameters into matrices (weights) and column vectors (biases), respectively
        w_1, w_2 = w_1.squeeze(), w_2.squeeze()
        b_1, b_2 = b_1.unsqueeze(1), b_2.unsqueeze(1)

        # run the feed-forward layer
        result = layer(data).data

        # CHECK: the output has the same shape as the input
        self.assertEqual(data.size(), result.size())

        for sample, sample_result in zip(data, result):  # iterate over all samples in the batch
            for token, token_result in zip(sample, sample_result):  # iterate over all tokens in the input sequence

                # compute target value by applying both of the layers manually to the token (as column vector)
                hidden = functional.relu(w_1.matmul(token.unsqueeze(1)) + b_1)
                expected = (w_2.matmul(hidden) + b_2).squeeze()

                # CHECK: the corresponding token has been processed correctly
                self.assertLessEqual(
                        (expected - token_result).abs(),
                        torch.ones(expected.size()) * self.TOLERANCE
                )
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n" 31 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n" 32 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n" 33 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n" 34 | "ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n" 35 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n" 36 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 37 | ) 38 | __license__ = "BSD-2-Clause" 39 | __version__ = "2018.1" 40 | __date__ = "Aug 22, 2018" 41 | __maintainer__ = "Patrick Hohenecker" 42 | __email__ = "mail@paho.at" 43 | __status__ = "Development" 44 | 45 | 46 | class MultiHeadAttentionTest(ttc.TorchTestCase): 47 | 48 | TOLERANCE = 1e-5 49 | 50 | def setUp(self): 51 | self.eps = self.TOLERANCE 52 | np.seterr(all="raise") 53 | 54 | self.num_heads = 2 55 | self.dim_model = 4 56 | self.dim_keys = 2 57 | self.dim_values = 3 58 | self.batch_size = 2 59 | self.seq_len = 3 60 | 61 | # create attention mechanism for testing 62 | self.attn = mha.MultiHeadAttention(self.num_heads, self.dim_model, self.dim_keys, self.dim_values, 0) 63 | 64 | # create dummy data 65 | self.queries_0 = np.array( 66 | [ 67 | [0, 0, 0, 0], 68 | [4, 4, 4, 4], 69 | [8, 8, 8, 8] 70 | ], 71 | dtype=np.float32 72 | ) 73 | self.keys_0 = 2 * self.queries_0 74 | self.values_0 = 3 * self.queries_0 75 | self.queries_1 = np.array( 76 | [ 77 | [0, 1, 2, 3], 78 | [4, 5, 6, 7], 79 | [8, 9, 10, 11] 80 | ], 81 | dtype=np.float32 82 | ) 83 | self.keys_1 = 2 * self.queries_1 84 | self.values_1 = 3 * self.queries_1 85 | 86 | # create tensors to provided as input data to the attention mechanism 87 | self.in_queries = torch.stack( 88 | [ 89 | torch.from_numpy(self.queries_0), 90 | torch.from_numpy(self.queries_1) 91 | ] 92 | ) 93 | self.in_keys = torch.stack( 94 | [ 95 | 
torch.from_numpy(self.keys_0), 96 | torch.from_numpy(self.keys_1) 97 | ] 98 | ) 99 | self.in_values = torch.stack( 100 | [ 101 | torch.from_numpy(self.values_0), 102 | torch.from_numpy(self.values_1) 103 | ] 104 | ) 105 | 106 | def assertArrayEqualsTensor(self, a: np.ndarray, t: torch.Tensor): 107 | if np.abs(a - t.detach().numpy()).max() > self.TOLERANCE: 108 | raise AssertionError("The values are different!") 109 | 110 | def test_apply_attention(self): 111 | # project queries, keys, and values to the needed dimensions 112 | in_queries, in_keys, in_values = self.attn._project_inputs(self.in_queries, self.in_keys, self.in_values) 113 | 114 | # CHECK: ensure that inputs have dimensions as expected 115 | self.assertEqual((self.batch_size, self.num_heads, self.seq_len, self.dim_keys), in_queries.size()) 116 | self.assertEqual((self.batch_size, self.num_heads, self.seq_len, self.dim_keys), in_keys.size()) 117 | self.assertEqual((self.batch_size, self.num_heads, self.seq_len, self.dim_values), in_values.size()) 118 | 119 | # compute attended values 120 | attn_values = self.attn._apply_attention(in_queries, in_keys, in_values, None) 121 | 122 | # CHECK: the retrieved tensor has the correct shape 123 | self.assertEqual((self.batch_size, self.num_heads, self.seq_len, self.dim_values), attn_values.size()) 124 | 125 | for sample_idx in range(self.batch_size): # iterate over all samples 126 | for head_idx in range(self.num_heads): # iterate over all heads 127 | 128 | # compute the attention scores for the current head 129 | attn_scores = torch.matmul( 130 | in_queries[sample_idx][head_idx], 131 | in_keys[sample_idx][head_idx].transpose(0, 1) 132 | ) 133 | attn_scores /= np.sqrt(self.dim_keys) 134 | attn_scores = functional.softmax(attn_scores, dim=1) 135 | 136 | # compute attended values for the current head 137 | target_attn_values = torch.matmul(attn_scores, in_values[sample_idx][head_idx]) 138 | 139 | # CHECK: the retrieved attention values are correct 140 | 
def test_project_inputs(self):
    """Check shapes and values of the per-head projections returned by ``_project_inputs``."""
    # projection matrices of the first two heads as numpy arrays (list index = head index)
    query_projections = [self.attn.query_projection[head].detach().numpy() for head in (0, 1)]
    key_projections = [self.attn.key_projection[head].detach().numpy() for head in (0, 1)]
    value_projections = [self.attn.value_projection[head].detach().numpy() for head in (0, 1)]

    # CHECK: the fixture inputs have the expected dimensions
    input_shape = (self.batch_size, self.seq_len, self.dim_model)
    for in_tensor in (self.in_queries, self.in_keys, self.in_values):
        self.assertEqual(input_shape, in_tensor.size())

    # project the inputs
    proj_queries, proj_keys, proj_values = self.attn._project_inputs(
        self.in_queries,
        self.in_keys,
        self.in_values
    )

    # CHECK: the projections have one slice per head and the per-head key/value dimensions
    key_shape = (self.batch_size, self.num_heads, self.seq_len, self.dim_keys)
    value_shape = (self.batch_size, self.num_heads, self.seq_len, self.dim_values)
    self.assertEqual(key_shape, proj_queries.size())
    self.assertEqual(key_shape, proj_keys.size())
    self.assertEqual(value_shape, proj_values.size())

    # CHECK: every (sample, head) slice equals the raw sample multiplied by that head's matrix
    cases = (
        (proj_queries, (self.queries_0, self.queries_1), query_projections),
        (proj_keys, (self.keys_0, self.keys_1), key_projections),
        (proj_values, (self.values_0, self.values_1), value_projections)
    )
    for projected, raw_samples, head_matrices in cases:
        for sample_idx, raw_sample in enumerate(raw_samples):
            for head_idx, head_matrix in enumerate(head_matrices):
                self.assertArrayEqualsTensor(
                    np.matmul(raw_sample, head_matrix),
                    projected[sample_idx][head_idx]
                )


def test_project_output(self):
    """Check that ``_project_output`` matches concatenating both heads and applying ``output_projection``."""
    projection = self.attn.output_projection

    # attend over the projected fixture inputs without any mask
    projected_inputs = self.attn._project_inputs(self.in_queries, self.in_keys, self.in_values)
    attn_values = self.attn._apply_attention(*projected_inputs, None)

    # CHECK: the attended values have the expected shape
    self.assertEqual((self.batch_size, self.num_heads, self.seq_len, self.dim_values), attn_values.size())

    # project the attended values back
    output = self.attn._project_output(attn_values)

    # CHECK: the output has the expected shape
    self.assertEqual((self.batch_size, self.seq_len, self.dim_model), output.size())

    for sample_idx in range(self.batch_size):
        sample_values = attn_values[sample_idx]

        for query_idx in range(self.seq_len):

            # row vector holding the values of head 0 followed by those of head 1
            head_rows = [sample_values[0][query_idx], sample_values[1][query_idx]]
            concat_values = torch.cat(head_rows).unsqueeze(0)

            # expected output for this query
            target_output = torch.matmul(concat_values, projection).squeeze()

            # CHECK: the computed output agrees with the expectation up to TOLERANCE
            self.eps = self.TOLERANCE
            self.assertEqual(target_output, output[sample_idx][query_idx])
class NormalizationTest(ttc.TorchTestCase):
    """Tests for the ``normalization.Normalization`` layer."""

    def test_forward(self):
        """Compare the layer's output with a per-token reference and check that all-zero input maps to ~0.

        The reference for each token is ``(x - mean(x)) / (std(x) + eps)``, computed with ``torch.mean`` and
        ``torch.std`` over the token's values and the layer's own ``eps``.
        """
        # create test data
        data = torch.FloatTensor(
                [
                        [
                                [1, 2, 3],
                                [11, 22, 33]
                        ],
                        [
                                [111, 222, 333],
                                [1111, 2222, 3333]
                        ]
                ]
        )

        # create layer used for testing
        norm = normalization.Normalization()

        # run layer to normalize the data
        norm_data = norm(data)

        for sample_idx in range(data.size(0)):  # iterate over all samples in the batch
            for token_idx in range(data.size(1)):  # iterate over all tokens in the sequences

                # normalize current token
                norm_token = data[sample_idx, token_idx]
                norm_token = norm_token - torch.mean(norm_token)
                norm_token = norm_token / (torch.std(norm_token) + norm.eps)

                # CHECK: the data has been normed correctly
                self.assertLessEqual(
                        (norm_token - norm_data[sample_idx, token_idx]).abs(),
                        torch.ones(norm_token.size()) * norm.eps
                )

        # run layer on 0-data
        norm_data = norm(torch.zeros(1, 2, 3))

        # CHECK: the data is approximately zero
        # (the largest *magnitude* is compared, as bounding only the maximum would let negative outliers pass)
        self.assertLessEqual(norm_data.abs().max().item(), norm.eps)
-------------------------------------------------------------------------------- /src/test/python/transformer_test/transformer_tools_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import torch 5 | import torchtestcase as ttc 6 | 7 | from unittest import mock 8 | 9 | from torch import nn 10 | 11 | from transformer import transformer 12 | from transformer import transformer_tools as tt 13 | 14 | 15 | __author__ = "Patrick Hohenecker" 16 | __copyright__ = ( 17 | "Copyright (c) 2018, Patrick Hohenecker\n" 18 | "All rights reserved.\n" 19 | "\n" 20 | "Redistribution and use in source and binary forms, with or without\n" 21 | "modification, are permitted provided that the following conditions are met:\n" 22 | "\n" 23 | "1. Redistributions of source code must retain the above copyright notice, this\n" 24 | " list of conditions and the following disclaimer.\n" 25 | "2. Redistributions in binary form must reproduce the above copyright notice,\n" 26 | " this list of conditions and the following disclaimer in the documentation\n" 27 | " and/or other materials provided with the distribution.\n" 28 | "\n" 29 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n" 30 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n" 31 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n" 32 | "DISCLAIMED. 
class TransformerToolsTest(ttc.TorchTestCase):
    """Tests for ``transformer_tools.eval_probability`` and ``transformer_tools.sample_output``."""

    TOLERANCE = 1e-6

    def test_eval_probability(self):
        """Check ``eval_probability`` against hand-computed values, using a model with a mocked ``forward``."""
        pad_index = 0

        # input sequences (the actual values are irrelevant, since the model's forward method is mocked)
        input_seq = torch.LongTensor(
                [
                        [1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 1, 1, 0],
                        [1, 1, 0, 0, 0, 0]
                ]
        )

        # target sequences, padded with pad_index
        target_seq = torch.LongTensor(
                [
                        [1, 3, 0],
                        [4, 2, 1],
                        [2, 0, 0]
                ]
        )

        # distributions that the mocked model returns, indexed as [sample][step][token]
        uniform = [0.2, 0.2, 0.2, 0.2, 0.2]
        mock_probs = torch.FloatTensor(
                [
                        [uniform, [0.05, 0.05, 0.05, 0.8, 0.05], [0.4, 0.3, 0.2, 0.05, 0.05]],
                        [uniform, [0.0, 0.0, 1.0, 0.0, 0.0], [0.1, 0.6, 0.05, 0.05, 0.2]],
                        [uniform, [0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]]
                ]
        )

        # expected results: 0.2 * 0.8, 0.2 * 1.0 * 0.6, and 0.2, i.e., the mocked probabilities of the target
        # tokens at the non-padding positions multiplied together
        target_probs = torch.FloatTensor([0.16, 0.12, 0.2])

        # prepare a model whose forward pass yields the mocked distributions
        model = transformer.Transformer(nn.Embedding(5, 5), pad_index, 10, max_seq_len=10)
        model.forward = mock.MagicMock(return_value=mock_probs)

        # evaluate the probabilities of the targets as predicted by the model
        probs = tt.eval_probability(model, input_seq, target_seq, pad_index=pad_index)

        # CHECK: the probabilities are as expected
        deviation = (target_probs - probs).abs()
        self.assertLessEqual(deviation, self.TOLERANCE)

    def test_sample_output(self):
        """Check that ``sample_output`` reproduces the sequences dictated by a mocked ``forward``."""
        eos_index = 0
        pad_index = 1

        def one_hot(token):
            # distribution over the 5 output tokens that puts all probability mass on the given one
            return [1.0 if idx == token else 0.0 for idx in range(5)]

        # mock outputs of the model, one tensor per time step, indexed as [sample][position][token]
        outputs = [
                torch.FloatTensor(  # time step 0
                        [
                                [one_hot(2)],  # sample 0: output 0 -> 2
                                [one_hot(3)]   # sample 1: output 0 -> 3
                        ]
                ),
                torch.FloatTensor(  # time step 1
                        [
                                [one_hot(2), one_hot(0)],  # sample 0: output 1 -> 0 = EOS
                                [one_hot(3), one_hot(4)]   # sample 1: output 1 -> 4
                        ]
                ),
                torch.FloatTensor(  # time step 2
                        [
                                [one_hot(2), one_hot(0), one_hot(2)],  # sample 0: output 2 -> IRRELEVANT
                                [one_hot(3), one_hot(4), one_hot(0)]   # sample 1: output 2 -> 0 = EOS
                        ]
                )
        ]

        def forward_patch(_, trans_target_seq) -> torch.FloatTensor:
            # choose the mock output based on the length of the provided target sequence
            return outputs[trans_target_seq.size(1) - 1]

        # prepare the input sequence
        input_seq = torch.LongTensor(
                [
                        [2, 2, 2, 2, 2, 2],
                        [3, 3, 3, 3, 3, 3]
                ]
        )

        # prepare the target output sequence
        target = torch.LongTensor(
                [
                        [2, eos_index, pad_index],
                        [3, 4, eos_index]
                ]
        )

        # prepare model
        model = transformer.Transformer(
                nn.Embedding(5, 5),  # word_emb
                pad_index,           # pad_index
                5,                   # output_size
                max_seq_len=100
        )

        # patch the model to return the probabilities defined above
        forward_mock = mock.Mock(side_effect=forward_patch)
        with mock.patch("transformer.Transformer.forward", forward_mock):

            # generate an output sequence
            output_seq = tt.sample_output(model, input_seq, eos_index, pad_index, 100)

            # CHECK: the generated sequence tensor has the expected shape
            self.assertEqual(target.shape, output_seq.shape)

            # CHECK: the generated sequences are as expected
            self.assertEqual(target, output_seq)
class UtilTest(ttc.TorchTestCase):
    """Tests for the helper functions in ``transformer.util``."""

    TOLERANCE = 1e-5

    def test_create_padding_mask(self):
        """Check type, shape, and values of the mask created for sequences padded with index 0."""
        # create test data
        seq = torch.LongTensor(
                [
                        [4, 2, 3, 5, 1, 0],
                        [1, 4, 0, 0, 0, 0],
                        [6, 3, 2, 4, 5, 1],
                        [5, 3, 2, 3, 0, 0],
                        [0, 0, 0, 0, 0, 0]
                ]
        )

        # expected mask row of each sample (1 = actual token, 0 = padding), which is repeated once per position
        mask_rows = (
                [1, 1, 1, 1, 1, 0],
                [1, 1, 0, 0, 0, 0],
                [1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0]
        )
        target_mask = torch.ByteTensor([[row] * 6 for row in mask_rows])

        # create the padding mask
        mask = util.create_padding_mask(seq, 0)

        # CHECK: the retrieved mask is a ByteTensor
        self.assertIsInstance(mask, torch.ByteTensor)

        # CHECK: the mask has the correct shape
        self.assertEqual((seq.size(0), seq.size(1), seq.size(1)), mask.size())

        # CHECK: the mask contains the correct values
        self.assertEqual(target_mask, mask)

    def test_create_positional_emb(self):
        """Check the created embedding's size and compare every entry with the sine/cosine reference."""
        dim_model = 512
        embedding_size = 300
        max_seq_len = 100

        # create positional embeddings
        pos_emb = util.create_positional_emb(max_seq_len, embedding_size, dim_model)

        # CHECK: the retrieved instance is an embedding of correct size
        self.assertIsInstance(pos_emb, nn.Embedding)
        self.assertEqual(max_seq_len, pos_emb.num_embeddings)
        self.assertEqual(embedding_size, pos_emb.embedding_dim)

        for pos in range(max_seq_len):  # iterate over the embeddings for all time steps

            # fetch the embedding vector of the current position as a list of floats
            emb_vec = pos_emb(torch.LongTensor([pos])).squeeze().tolist()

            for i, emb_val in enumerate(emb_vec):  # iterate over all values of the current time step

                # even dimensions are compared with a sine, odd ones with a cosine at the preceding even index
                if i % 2 == 0:
                    expected = np.sin(pos / (10000 ** (i / dim_model)))
                else:
                    expected = np.cos(pos / (10000 ** ((i - 1) / dim_model)))

                # CHECK: the considered value is the one expected
                self.assertLessEqual(np.abs(expected - emb_val), self.TOLERANCE)

    def test_create_shifted_output_mask(self):
        """Check type, shape, and values of the mask created for a batch of two length-4 sequences."""
        # create test data
        seq = torch.ones(2, 4).long()

        # expected mask of a single sample (lower-triangular), which is the same for both samples
        sample_mask = [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1]
        ]
        target_mask = torch.ByteTensor([sample_mask, sample_mask])

        # create the mask
        mask = util.create_shifted_output_mask(seq)

        # CHECK: the retrieved mask is a ByteTensor
        self.assertIsInstance(mask, torch.ByteTensor)

        # CHECK: the mask has the correct shape
        self.assertEqual((seq.size(0), seq.size(1), seq.size(1)), mask.size())

        # CHECK: the mask contains the correct values
        self.assertEqual(target_mask, mask)

    def test_shift_output_sequence(self):
        """Check that shifting prepends a zero vector to each sequence and drops its last element."""
        zeros = [0, 0, 0, 0, 0]

        # rows of the two test sequences
        first = [[1, 2, 3, 4, 5], [11, 22, 33, 44, 55], [111, 222, 333, 444, 555]]
        second = [[6, 7, 8, 9, 0], [66, 77, 88, 99, 0], [666, 777, 888, 999, 0]]

        # create test data
        seq = torch.FloatTensor([first, second])
        target = torch.FloatTensor(
                [
                        [zeros] + first[:-1],
                        [zeros] + second[:-1]
                ]
        )

        # shift the sequences
        shifted_seq = util.shift_output_sequence(seq)

        # CHECK: the sequence has been shifted correctly
        self.eps = 1e-22
        self.assertEqual(target, shifted_seq)