├── src ├── __init__.py ├── modules │ ├── __init__.py │ ├── smoothing.py │ ├── ffn_test.py │ ├── embedding_test.py │ ├── smoothing_test.py │ ├── ffn.py │ └── embedding.py ├── models │ ├── encoder_test.py │ ├── decoder_test.py │ ├── encoder.py │ ├── decoder.py │ └── transformer.py └── utils │ ├── hooks.py │ └── __init__.py ├── requirements.txt ├── hyperparams └── config.yaml ├── .gitignore ├── data.py ├── train.py └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import * 2 | from .ffn import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-ignite==0.1.2 2 | torch==1.0.1.post2 3 | numpy==1.22.0 4 | torchtext==0.3.1 5 | six==1.10.0 -------------------------------------------------------------------------------- /src/modules/smoothing.py: -------------------------------------------------------------------------------- 1 | def label_smoothing(labels, eps=0.1): 2 | C = labels.size(1) 3 | return ((1 - eps) * labels) + (eps / C) 4 | 5 | -------------------------------------------------------------------------------- /src/models/encoder_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .encoder import TransformerEncoder 4 | 5 | 6 | def test_encoder(): 7 | encoder = TransformerEncoder(512, 8, 4) 8 | X = torch.randn(3, 5, 512) 9 | enc = encoder(X) 10 | 11 | assert enc.size() == X.size() 12 | -------------------------------------------------------------------------------- /src/models/decoder_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .decoder import TransformerDecoder 4 | 5 | 6 | def test_decoder(): 7 | decoder = TransformerDecoder(512, 8, 6) 8 | query = torch.randn(3, 10, 512, requires_grad=False) 9 | key = torch.randn(3, 5, 512, requires_grad=False) 10 | 11 | result = decoder(query, key) 12 | assert result.size() == query.size() 13 | -------------------------------------------------------------------------------- /src/modules/ffn_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from modules.ffn import PositionWiseFFN 5 | 6 | 7 | class PositionWiseFFNTest(unittest.TestCase): 8 | def test_ffn(self): 9 | inputs = torch.randn(3, 5, 512) 10 | num_units = [2048, 512] 11 | pwffn = PositionWiseFFN(inputs.size(-1), num_units) 12 | result = pwffn(inputs) 13 | 14 | self.assertEqual(result.size(), inputs.size()) 15 | 16 | 17 | if __name__ == "__main__": 18 | unittest.main() 19 | -------------------------------------------------------------------------------- /src/modules/embedding_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from modules.embedding import TransformerEmbedding 5 | 6 | 7 | def test_embedding(): 8 | vocab_size = 100 9 | max_length = 150 10 | embedding_size = 300 11 | 12 | sequence = torch.tensor(np.random.randint(0, 100, size=(3, 5)), requires_grad=False) 13 | embedding = TransformerEmbedding(vocab_size, max_length, embedding_size, 1) 14 | 15 | result = 
embedding(sequence) 16 | assert result.size() == (3, 5, embedding_size) 17 | -------------------------------------------------------------------------------- /src/modules/smoothing_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from hypothesis import given 5 | from hypothesis.strategies import integers 6 | 7 | from modules.embedding import OneHotEmbedding 8 | from modules.smoothing import label_smoothing 9 | 10 | 11 | class SmoothingTest(unittest.TestCase): 12 | @given(integers(1, 1000)) 13 | def test_label_smoothing(self, C): 14 | onehot = OneHotEmbedding(1000) 15 | smoothed_label = label_smoothing(onehot(torch.LongTensor([C]))) 16 | 17 | test = smoothed_label.sum(1).data.numpy()[0] 18 | self.assertAlmostEqual(round(test), 1.0) 19 | -------------------------------------------------------------------------------- /src/utils/hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pathlib import Path 4 | 5 | 6 | def restore_checkpoint_hook(model, model_path, logger=print): 7 | def restore_checkpoint(engine): 8 | try: 9 | model_file = Path(model_path) 10 | if model_file.exists(): 11 | logger("Start restore model...") 12 | model.load_state_dict(torch.load(model_path)) 13 | logger("Finish restore model!") 14 | else: 15 | logger("Model not found, skip restoring model") 16 | except Exception as e: 17 | logger("Something wrong while restoring the model: %s" % str(e)) 18 | 19 | return restore_checkpoint 20 | -------------------------------------------------------------------------------- /src/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ..modules.ffn import PositionWiseFFN 4 | 5 | 6 | class _Layer(nn.Module): 7 | def __init__(self, dim, num_head): 8 | super(_Layer, self).__init__() 9 | 10 | self.attn = nn.MultiheadAttention(dim, num_head) 11 | self.pffn = PositionWiseFFN(dim) 12 | 13 | def forward(self, src): 14 | out = self.attn(src, src, src)[0] 15 | out = self.pffn(out) 16 | 17 | return out 18 | 19 | 20 | class TransformerEncoder(nn.Module): 21 | def __init__(self, dim, num_head, num_layers): 22 | super().__init__() 23 | 24 | self.layer = nn.Sequential(*[_Layer(dim, num_head) for _ in range(num_layers)]) 25 | 26 | def forward(self, src): 27 | return self.layer(src) 28 | -------------------------------------------------------------------------------- /hyperparams/config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | train: "/Users/adityakurniawan/Workspace/open-source/pytorch-transformer/wmt14/train" 3 | eval: "/Users/adityakurniawan/Workspace/open-source/pytorch-transformer/wmt14/eval" 4 | test: "/Users/adityakurniawan/Workspace/open-source/pytorch-transformer/wmt14/test" 5 | source_ext: "en" 6 | target_ext: "de" 7 | batch_size: 32 8 | source_min_freq: 10 9 | target_min_freq: 10 10 | source_max_freq: 10000 11 | target_max_freq: 10000 12 | training: 13 | epochs: 10 14 | learning_rate: 0.0004 15 | max_len: 100 16 | decay_step: 500 17 | decay_percent: 0.1 18 | val_log: 100 19 | checkpoint: "./checkpoints" 20 | model: 21 | encoder_emb_size: 512 22 | decoder_emb_size: 512 23 | enc_dim: 512 24 | enc_num_head: 8 25 | enc_num_layer: 1 26 | dec_dim: 512 27 | dec_num_head: 8 28 | dec_num_layer: 1 29 | -------------------------------------------------------------------------------- /src/modules/ffn.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class PositionWiseFFN(nn.Module): 5 | def __init__(self, feature_size, num_units=2048, dropout=0.1): 6 | super(PositionWiseFFN, self).__init__() 7 | self._dropout = dropout 8 | self.ffn = nn.Sequential( 9 | nn.LayerNorm(feature_size), 10 | nn.Linear(feature_size, num_units), 11 | nn.ReLU(), 12 | nn.Dropout(dropout), 13 | nn.Linear(num_units, feature_size), 14 | ) 15 | self.ln = nn.LayerNorm(feature_size) 16 | 17 | def forward(self, X): 18 | ffn = self.ffn(X) 19 | # residual network 20 | ffn += X 21 | ffn = self.ln(ffn) 22 | 23 | return ffn 24 | 25 | def init_weight(self): 26 | for idx in range(len(self.ffn)): 27 | if hasattr(self.ffn[idx], "weight"): 28 | nn.init.uniform_(self.ffn[idx].weight, -0.1, 0.1) 29 | -------------------------------------------------------------------------------- /src/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ..modules.ffn import PositionWiseFFN 4 | 5 | 6 | class _Layer(nn.Module): 7 | def __init__(self, dim, num_head): 8 | super(_Layer, self).__init__() 9 | 10 | self.self_attn = nn.MultiheadAttention(dim, num_head) 11 | self.lookbehind_attn = nn.MultiheadAttention(dim, num_head) 12 | self.pffn = PositionWiseFFN(dim) 13 | 14 | self.ln1 = nn.LayerNorm(dim) 15 | self.ln2 = nn.LayerNorm(dim) 16 | 17 | def forward(self, src, tgt): 18 | out = self.self_attn(tgt, tgt, tgt)[0] 19 | out += tgt 20 | out = self.ln1(out) 21 | look_out = self.lookbehind_attn(tgt, src, src)[0] 22 | out += look_out 23 | out = self.ln2(out) 24 | out = self.pffn(out) 25 | 26 | return out 27 | 28 | 29 | class TransformerDecoder(nn.Module): 30 | def __init__(self, dim, num_head, num_layers): 31 | super().__init__() 32 | 33 | self.layers = nn.ModuleList([_Layer(dim, num_head) for _ in range(num_layers)]) 34 | 35 | def forward(self, src, tgt): 36 | out = tgt 37 | for layer in self.layers: 38 | out = layer(out, src) 39 | 40 | return out 41 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def print_current_prediction(predictions, targets, vocab): 5 | n_sample = 100 if len(targets) >= 100 else len(targets) 6 | result_str = "" 7 | bold_code = "\033[1m" 8 | end_bold_code = "\033[0m" 9 | result_str += "Current state of the model\n" 10 | result_str += ("=" * 100) + "\n" 11 | rand_ids = random.sample(range(len(targets)), n_sample) 12 | pred_sample = [predictions[i] for i in rand_ids] 13 | trg_sample = [targets[i] for i in rand_ids] 14 | for this_idx, (pred, trg) in enumerate(zip(pred_sample, trg_sample)): 15 | vocab_mapper = lambda x: vocab.itos[x] 16 | preds = list(map(vocab_mapper, pred)) 17 | trgs = list(map(vocab_mapper, trg)) 18 | result_str += "{}Prediction{}: {}\n".format( 19 | bold_code, end_bold_code, " ".join(preds[: len(trgs)]) 20 | ) 21 | result_str += "{}Target{}: {}\n".format( 22 | bold_code, end_bold_code, " ".join(trgs) 23 | ) 24 | result_str += "{}Difference of length{}: {}\n\n".format( 25 | bold_code, end_bold_code, abs(len(preds) - len(trgs)) 26 | ) 27 | if this_idx < len(pred_sample) - 1: 28 | result_str += "\n" 29 | result_str += "=" * 100 30 | print(result_str) 31 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # pytorch 2 | checkpoints 3 | 4 | # torchtext 5 | .data 6 | 7 | # Visual studio code 8 | .vscode 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | .static_storage/ 65 | .media/ 66 | local_settings.py 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | wmt14 115 | outputs 116 | -------------------------------------------------------------------------------- /src/modules/embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class TransformerEmbedding(nn.Module): 8 | def __init__(self, word_embedding, positional_embedding=None): 9 | super(TransformerEmbedding, self).__init__() 10 | # add + 3 for pad, unk and eos/bos 11 | self.word_embedding = word_embedding 12 | # self.word_embedding = nn.Embedding(vocab_size + 3, 13 | # embedding_size, 14 | # padding_idx=padding_idx) 15 | 16 | self.positional_embedding = positional_embedding 17 | 18 | def forward(self, X): 19 | out = self.word_embedding(X) 20 | if self.positional_embedding: 21 | out = self.positional_embedding(out) 22 | 23 | return out 24 | 25 | 26 | class PositionalEncoding(nn.Module): 27 | def __init__(self, d_model, dropout=0.1, max_len=5000): 28 | super().__init__() 29 | 30 | self.dropout = nn.Dropout(p=dropout) 31 | 32 | pe = torch.zeros(max_len, d_model) 33 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 34 | div_term = torch.exp( 35 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 36 | ) 37 | pe[:, 0::2] = torch.sin(position * div_term) 38 | pe[:, 1::2] = torch.cos(position * div_term) 39 | pe = pe.unsqueeze(0).transpose(0, 1) 40 | self.register_buffer("pe", pe) 41 | 42 | def forward(self, x): 43 | x = x + self.pe[: x.size(0), :] 44 | return self.dropout(x) 45 | 46 | 47 | class OneHotEmbedding(nn.Module): 48 | def __init__(self, num_class): 49 | 
super(OneHotEmbedding, self).__init__() 50 | self.embed = nn.Embedding(num_class, num_class) 51 | self.embed.weight.data = self._build_onehot(num_class) 52 | # to prevent the weight getting trained 53 | self.embed.weight.requires_grad = False 54 | 55 | def _build_onehot(self, num_class): 56 | onehot = torch.eye(num_class) 57 | return onehot 58 | 59 | def forward(self, x): 60 | return self.embed(x) 61 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchtext import data 4 | from torchtext import datasets 5 | 6 | 7 | def create_dataset(cfg, logger=print): 8 | 9 | if torch.cuda.is_available(): 10 | device_data = "cuda" 11 | else: 12 | device_data = "cpu" 13 | 14 | source = data.Field(batch_first=False, lower=True, init_token="") 15 | target = data.Field(batch_first=False, lower=True, eos_token="") 16 | 17 | fields = (source, target) 18 | if cfg.train and cfg.eval and cfg.test: 19 | exts = ("." + cfg.source_ext, "." + cfg.target_ext) 20 | train_path = cfg.train 21 | val_path = cfg.eval 22 | test_path = cfg.test 23 | 24 | train, val, test = get_mt_datasets(exts, fields, train_path, val_path, 25 | test_path) 26 | else: 27 | logger("neither train_path or val_path were defined. " 28 | "using WMT14 dataset as a fallback") 29 | train, val, test = get_wmt_dataset((".de", ".en"), fields) 30 | 31 | source.build_vocab(train.src, 32 | min_freq=cfg.source_min_freq, 33 | max_size=cfg.source_max_freq) 34 | target.build_vocab(train.trg, 35 | min_freq=cfg.target_min_freq, 36 | max_size=cfg.target_max_freq) 37 | 38 | train_iter, val_iter, test_iter = data.BucketIterator.splits( 39 | (train, val, test), 40 | batch_size=cfg.batch_size, 41 | repeat=False, 42 | shuffle=True, 43 | device=device_data) 44 | 45 | return train_iter, val_iter, test_iter, source.vocab, target.vocab 46 | 47 | 48 | def get_wmt_dataset(exts, fields): 49 | train, val, test = datasets.WMT14.splits(exts=exts, fields=fields) 50 | return train, val, test 51 | 52 | 53 | def get_mt_datasets(exts, fields, train_path, val_path, test_path): 54 | train = datasets.TranslationDataset(path=train_path, 55 | exts=exts, 56 | fields=fields) 57 | val = datasets.TranslationDataset(path=val_path, exts=exts, fields=fields) 58 | test = datasets.TranslationDataset(path=test_path, 59 | exts=exts, 60 | fields=fields) 61 | return train, val, test 62 | -------------------------------------------------------------------------------- /src/models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .encoder import TransformerEncoder 5 | from .decoder import TransformerDecoder 6 | from ..modules.embedding import PositionalEncoding, TransformerEmbedding 7 | 8 | 9 | class Transformer(nn.Module): 10 | def __init__( 11 | self, 12 | max_length, 13 | enc_vocab, 14 | dec_vocab, 15 | enc_emb_size, 16 | dec_emb_size, 17 | enc_dim, 18 | enc_num_head, 19 | enc_num_layer, 20 | dec_dim, 21 | dec_num_head, 22 | dec_num_layer, 23 | dropout_rate=0.1, 24 | ): 25 | super(Transformer, self).__init__() 26 | enc_vocab_size = len(enc_vocab.itos) 27 | dec_vocab_size = len(dec_vocab.itos) 28 | 29 | word_enc_embedding = nn.Embedding( 30 | enc_vocab_size + 3, enc_emb_size, padding_idx=enc_vocab.stoi[""] 31 | ) 32 | pos_encoder = PositionalEncoding(enc_emb_size) 33 | word_dec_embedding = nn.Embedding( 34 | dec_vocab_size + 3, dec_emb_size, 
padding_idx=dec_vocab.stoi[""] 35 | ) 36 | pos_decoder = PositionalEncoding(dec_emb_size) 37 | self.encoder_embedding = TransformerEmbedding(word_enc_embedding, pos_encoder) 38 | self.decoder_embedding = TransformerEmbedding(word_dec_embedding, pos_decoder) 39 | 40 | self.encoder = TransformerEncoder(enc_dim, enc_num_head, enc_num_layer) 41 | self.decoder = TransformerDecoder(dec_dim, dec_num_head, dec_num_layer) 42 | 43 | self.logits_layer = nn.Linear(in_features=dec_dim, out_features=dec_vocab_size) 44 | self.softmax = nn.Softmax(dim=-1) 45 | 46 | def forward(self, enc_input, dec_input): 47 | enc_embed = self.encoder_embedding(enc_input) 48 | encoder_result = self.encoder(enc_embed) 49 | 50 | dec_embed = self.decoder_embedding(dec_input) 51 | decoder_result = self.decoder(dec_embed, encoder_result) 52 | 53 | logits = self.logits_layer(decoder_result.reshape(-1, decoder_result.size(-1))) 54 | softmax = self.softmax(logits) 55 | softmax = softmax.reshape(decoder_result.size(0), decoder_result.size(1), -1) 56 | 57 | return softmax, logits 58 | 59 | def generate_square_subsequent_mask(self, sz): 60 | """Generate a square mask for the sequence. 61 | The masked positions are filled with float('-inf'). 62 | Unmasked positions are filled with float(0.0). 63 | """ 64 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 65 | mask = ( 66 | mask.float() 67 | .masked_fill(mask == 0, float("-inf")) 68 | .masked_fill(mask == 1, float(0.0)) 69 | ) 70 | return mask 71 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import pytorch_lightning as pl 6 | 7 | from operator import iconcat 8 | from functools import reduce 9 | from omegaconf import DictConfig, OmegaConf 10 | from pytorch_lightning import Trainer, seed_everything 11 | # from src.utils.hooks import (validation_result_hook, restore_checkpoint_hook) 12 | from src.utils import print_current_prediction 13 | from torch.optim.lr_scheduler import StepLR 14 | 15 | from src.models.transformer import Transformer 16 | from data import create_dataset 17 | 18 | seed_everything(42) 19 | 20 | 21 | class MachineTranslationModel(pl.LightningModule): 22 | def __init__(self, hparams, source_vocab, target_vocab): 23 | super().__init__() 24 | 25 | self.hparams = hparams 26 | self.model = Transformer( 27 | max_length=hparams["training"]["max_len"], 28 | enc_vocab=source_vocab, 29 | dec_vocab=target_vocab, 30 | enc_emb_size=hparams["model"]["encoder_emb_size"], 31 | dec_emb_size=hparams["model"]["decoder_emb_size"], 32 | enc_dim=hparams["model"]["enc_dim"], 33 | enc_num_head=hparams["model"]["enc_num_head"], 34 | enc_num_layer=hparams["model"]["enc_num_layer"], 35 | dec_dim=hparams["model"]["dec_dim"], 36 | dec_num_head=hparams["model"]["dec_num_head"], 37 | dec_num_layer=hparams["model"]["dec_num_layer"]) 38 | self.criterion = nn.CrossEntropyLoss() 39 | 40 | self._source_vocab = source_vocab 41 | self._target_vocab = target_vocab 42 | 43 | def forward(self, batch): 44 | return self.model(batch.src, batch.trg) 45 | 46 | def training_step(self, batch, batch_idx): 47 | _, logits = self(batch) 48 | 49 | flattened_target = batch.trg.view(-1) 50 | loss = self.criterion(logits, flattened_target) 51 | 52 | tensorboard_logs = {'train_loss': loss.item()} 53 | 54 | return {'loss': loss, 'log': tensorboard_logs} 55 | 56 | def validation_step(self, 
batch, batch_idx): 57 | with torch.no_grad(): 58 | probs, logits = self(batch) 59 | 60 | flattened_target = batch.trg.view(-1) 61 | loss = self.criterion(logits, flattened_target) 62 | 63 | preds = probs.transpose(0, 1).argmax(-1).tolist() 64 | targets = batch.trg.t().tolist() 65 | 66 | return {"loss": loss, "predictions": preds, "targets": targets} 67 | 68 | def test_step(self, batch, batch_idx): 69 | with torch.no_grad(): 70 | probs, logits = self.model(batch.src, batch.trg) 71 | 72 | flattened_target = batch.trg.view(-1) 73 | loss = self.criterion(logits, flattened_target) 74 | 75 | preds = probs.transpose(0, 1).argmax(-1).tolist() 76 | targets = batch.trg.t().tolist() 77 | 78 | return { 79 | "loss": loss, # keep the tensor so test_epoch_end can torch.stack the losses 80 | "predictions": preds, 81 | "targets": targets 82 | } 83 | 84 | def validation_epoch_end(self, outputs): 85 | avg_loss = torch.stack([x["loss"] for x in outputs]).mean() 86 | predictions = reduce(iconcat, [x["predictions"] for x in outputs]) 87 | targets = reduce(iconcat, [x["targets"] for x in outputs]) 88 | 89 | print_current_prediction(predictions, targets, self._target_vocab) 90 | 91 | tensorboard_logs = {"avg_val_loss": avg_loss} 92 | return {"val_loss": avg_loss, "log": tensorboard_logs} 93 | 94 | def test_epoch_end(self, outputs): 95 | avg_loss = torch.stack([x["loss"] for x in outputs]).mean() 96 | predictions = reduce(iconcat, [x["predictions"] for x in outputs]) 97 | targets = reduce(iconcat, [x["targets"] for x in outputs]) 98 | 99 | print_current_prediction(predictions, targets, self._target_vocab) 100 | 101 | tensorboard_logs = {"avg_test_loss": avg_loss} 102 | return {"test_loss": avg_loss, "log": tensorboard_logs} 103 | 104 | def configure_optimizers(self): 105 | optimizer = optim.Adam(self.parameters(), 106 | lr=self.hparams["training"]["learning_rate"]) 107 | scheduler = StepLR(optimizer, 108 | step_size=self.hparams["training"]["decay_step"], 109 | gamma=self.hparams["training"]["decay_percent"]) 110 | return [optimizer], [scheduler] 111 | 112 | 113 | @hydra.main(config_path="hyperparams/config.yaml") 114 | def run(cfg: DictConfig): 115 | train_iter, val_iter, test_iter, source_vocab, target_vocab = create_dataset( 116 | cfg.dataset) 117 | model = MachineTranslationModel(OmegaConf.to_container(cfg, resolve=True), 118 | source_vocab, target_vocab) 119 | trainer = Trainer() 120 | trainer.fit(model, train_iter, val_iter) 121 | trainer.test(model, test_iter) 122 | 123 | 124 | if __name__ == '__main__': 125 | run() 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-transformer 2 | 3 | **WARNING! Still in the middle of refactoring, please revert to bb88f0a for the latest working version of this repo** 4 | 5 | Implementation of the "Attention Is All You Need" paper in PyTorch. 6 | Implemented components are: 7 | 8 | 1. Multi-Head Attention 9 | 2. Sinusoidal Positional Encoding 10 | 3. Position-wise FFN 11 | 4. Label Smoothing (unfortunately still unused in training, because `nn.CrossEntropyLoss` only accepts hard class-index targets and cannot consume a smoothed label distribution yet :( ) 12 | 13 | ## Requirements 14 | 15 | * [pytorch](http://pytorch.org/) 16 | * [torchtext](https://github.com/pytorch/text/) 17 | * [ignite](https://github.com/pytorch/ignite/) 18 | 19 | ## How to use? 20 | 21 | ### Training 22 | 23 | You can run the training simply with `python train.py` or pass additional arguments on the command line. 
Please see the example below 24 | 25 | ```bash 26 | python train.py --batch_size 3 --source_train_path wmt14/train.en --target_train_path wmt14/train.de --source_val_path wmt14/eval.en --target_val_path wmt14/eval.de 27 | ``` 28 | 29 | ### Arguments 30 | 31 | ```bash 32 | --batch_size BATCH_SIZE 33 | Number of batch in single iteration 34 | --source_train_path SOURCE_TRAIN_PATH 35 | Path for source training data. Ex: data/train.en 36 | --target_train_path TARGET_TRAIN_PATH 37 | Path for target training data. Ex: data/train.de 38 | --source_val_path SOURCE_VAL_PATH 39 | Path for source validation data. Ex: data/val.en 40 | --target_val_path TARGET_VAL_PATH 41 | Path for target validation data. Ex: data/val.de 42 | --epochs EPOCHS Number of epochs 43 | --learning_rate LEARNING_RATE 44 | Learning rate size 45 | --max_len MAX_LEN Maximum allowed sentence length 46 | --enc_max_vocab ENC_MAX_VOCAB 47 | Maximum vocabs for encoder 48 | --dec_max_vocab DEC_MAX_VOCAB 49 | Maximum vocabs for decoder 50 | --encoder_units ENCODER_UNITS 51 | Number of encoder units for every layers. Separable by 52 | commas 53 | --decoder_units DECODER_UNITS 54 | Number of decoder units for every layers. Separable by 55 | commas 56 | --encoder_emb_size ENCODER_EMB_SIZE 57 | Size of encoder's embedding 58 | --decoder_emb_size DECODER_EMB_SIZE 59 | Size of decoder's embedding 60 | --log_interval LOG_INTERVAL 61 | Print loss for every N steps 62 | --save_interval SAVE_INTERVAL 63 | Save model for every N steps 64 | --compare_interval COMPARE_INTERVAL 65 | Compare current prediction with its true label for 66 | every N steps 67 | --decay_step DECAY_STEP 68 | Learning rate will decay after N step 69 | --decay_percent DECAY_PERCENT 70 | Percent of decreased in learning rate decay 71 | --model_dir MODEL_DIR 72 | Location to save the model 73 | ``` 74 | 75 | ## Results 76 | 77 | With the following command 78 | 79 | ```bash 80 | python train.py --batch_size 16 --max_len 500 --decoder_units 512,512,512 --save_interval 500 --enc_max_vocab 37000 --dec_max_vocab 37000 --decay_step 4000 --epochs 5 81 | ``` 82 | 83 | I was able to get result as the following 84 | 85 | ```bash 86 | Train loss: 2.17e-02 87 | Validation loss: 0 88 | 89 | Validation result 90 | ================= 91 | Prediction: zwei anlagen so nah bei@@ einander : absicht oder schil@@ d@@ bürger@@ st@@ reich ? 92 | Target: zwei anlagen so nah bei@@ einander : absicht oder schil@@ d@@ bürger@@ st@@ reich ? 93 | Difference of length: 0 94 | 95 | 96 | Prediction: al@@ fred ab@@ el , der das grundstück derzeit verwaltet , hatte die ver@@ schön@@ erungs@@ aktion mit seinem kollegen rein@@ hard dom@@ ke von der b@@ i ab@@ gesprochen . 97 | Target: al@@ fred ab@@ el , der das grundstück derzeit verwaltet , hatte die ver@@ schön@@ erungs@@ aktion mit seinem kollegen rein@@ hard dom@@ ke von der b@@ i ab@@ gesprochen . 98 | Difference of length: 0 99 | 100 | 101 | Prediction: andere komit@@ e@@ em@@ it@@ gli@@ eder sagten , es gebe nur ver@@ einzel@@ te berichte von pi@@ loten , die für eine stör@@ ung von flugzeug@@ systemen durch die geräte sp@@ rä@@ chen , und die meisten davon seien sehr alt . 102 | Target: andere komit@@ e@@ em@@ it@@ gli@@ eder sagten , es gebe nur ver@@ einzel@@ te berichte von pi@@ loten , die für eine stör@@ ung von flugzeug@@ systemen durch die geräte sp@@ rä@@ chen , und die meisten davon seien sehr alt . 
103 | Difference of length: 0 104 | 105 | 106 | Prediction: der mann aus lamp@@ recht@@ sh@@ au@@ sen wollte an der außen@@ fass@@ ade eines gas@@ th@@ auses einen def@@ ekten heiz@@ ungs@@ füh@@ ler aus@@ wechseln . 107 | Target: der mann aus lamp@@ recht@@ sh@@ au@@ sen wollte an der außen@@ fass@@ ade eines gas@@ th@@ auses einen def@@ ekten heiz@@ ungs@@ füh@@ ler aus@@ wechseln . 108 | Difference of length: 0 109 | 110 | 111 | Prediction: falls allerdings in den kommenden monaten keine neuen bestellungen bekannt gegeben werden , dann erwarten wir , dass der markt dem programm gegenüber skep@@ tischer wird . 112 | Target: falls allerdings in den kommenden monaten keine neuen bestellungen bekannt gegeben werden , dann erwarten wir , dass der markt dem programm gegenüber skep@@ tischer wird . 113 | Difference of length: 0 114 | 115 | 116 | Prediction: es recht@@ fertigt nicht die ergebnisse eines berichts , in dem es heißt , die anstrengungen des wei@@ ßen hauses zur ru@@ hi@@ g@@ stellung der medien seien die „ aggres@@ si@@ v@@ sten ... seit der ni@@ x@@ on-@@ regierung “ . 117 | Target: es recht@@ fertigt nicht die ergebnisse eines berichts , in dem es heißt , die anstrengungen des wei@@ ßen hauses zur ru@@ hi@@ g@@ stellung der medien seien die „ aggres@@ si@@ v@@ sten ... seit der ni@@ x@@ on-@@ regierung “ . 118 | Difference of length: 0 119 | 120 | 121 | Prediction: " der schutz möglichst vieler be@@ bau@@ ter grund@@ stücke ist das ziel " , so h@@ äuß@@ ler . 122 | Target: " der schutz möglichst vieler be@@ bau@@ ter grund@@ stücke ist das ziel " , so h@@ äuß@@ ler . 123 | Difference of length: 0 124 | 125 | 126 | Prediction: das kon@@ sist@@ orium im nächsten jahr sei deshalb bedeut@@ sam , weil es das erste seit der wahl von fran@@ z@@ is@@ kus im märz diesen jahres sei , so val@@ ero . 127 | Target: das kon@@ sist@@ orium im nächsten jahr sei deshalb bedeut@@ sam , weil es das erste seit der wahl von fran@@ z@@ is@@ kus im märz diesen jahres sei , so val@@ ero . 128 | Difference of length: 0 129 | 130 | 131 | Prediction: samsung , hu@@ a@@ wei und ht@@ c stellen hand@@ ys her , die mit goo@@ g@@ les betriebssystem andro@@ id arbeiten , das in schar@@ f@@ em wettbewerb zu den mobil@@ produkten von apple und microsoft steht . 132 | Target: samsung , hu@@ a@@ wei und ht@@ c stellen hand@@ ys her , die mit goo@@ g@@ les betriebssystem andro@@ id arbeiten , das in schar@@ f@@ em wettbewerb zu den mobil@@ produkten von apple und microsoft steht . 133 | Difference of length: 0 134 | 135 | 136 | Prediction: etw@@ a gegen 14 : 15 uhr am mittwoch sah ein spaziergän@@ ger , der seinen hund aus@@ führte , die gest@@ rand@@ ete ru@@ by auf dem 15 meter hohen absatz im stein@@ bruch . 137 | Target: etw@@ a gegen 14 : 15 uhr am mittwoch sah ein spaziergän@@ ger , der seinen hund aus@@ führte , die gest@@ rand@@ ete ru@@ by auf dem 15 meter hohen absatz im stein@@ bruch . 138 | Difference of length: 0 139 | 140 | 141 | Prediction: c@@ ook sagte : „ nach den erhöh@@ ungen bei der st@@ emp@@ el@@ steuer auf hoch@@ prei@@ si@@ ge wohnungen und der einführung der damit verbundenen gesetzgebung gegen ein um@@ gehen kann man schwer@@ lich behaupten , hochwertige immobilien seien zu niedrig best@@ euert , ungeachtet der auswirkungen des ver@@ alt@@ eten gemein@@ dest@@ euer@@ systems . 
“ 142 | Target: c@@ ook sagte : „ nach den erhöh@@ ungen bei der st@@ emp@@ el@@ steuer auf hoch@@ prei@@ si@@ ge wohnungen und der einführung der damit verbundenen gesetzgebung gegen ein um@@ gehen kann man schwer@@ lich behaupten , hochwertige immobilien seien zu niedrig best@@ euert , ungeachtet der auswirkungen des ver@@ alt@@ eten gemein@@ dest@@ euer@@ systems . “ 143 | Difference of length: 0 144 | 145 | 146 | Prediction: ohne die unterstützung des einzigen anderen herstell@@ ers großer moderner j@@ ets sagen experten , dass der ruf nach einem neuen branchen@@ standard vermutlich ver@@ pu@@ ffen werde , aber von der welle von 7@@ 7@@ 7@@ x-@@ verk@@ äu@@ fen ab@@ lenken könnte . 147 | Target: ohne die unterstützung des einzigen anderen herstell@@ ers großer moderner j@@ ets sagen experten , dass der ruf nach einem neuen branchen@@ standard vermutlich ver@@ pu@@ ffen werde , aber von der welle von 7@@ 7@@ 7@@ x-@@ verk@@ äu@@ fen ab@@ lenken könnte . 148 | Difference of length: 0 149 | 150 | 151 | Prediction: " bildung ist ein wichtiger standor@@ t@@ faktor " , unter@@ strich clau@@ dia st@@ eh@@ le , direkt@@ or@@ in der hans@@ -@@ th@@ om@@ a-@@ schule , die das vern@@ etz@@ te schul@@ projekt bildungs@@ zentrum hoch@@ schwar@@ zw@@ ald vor@@ stellte . 152 | Target: " bildung ist ein wichtiger standor@@ t@@ faktor " , unter@@ strich clau@@ dia st@@ eh@@ le , direkt@@ or@@ in der hans@@ -@@ th@@ om@@ a-@@ schule , die das vern@@ etz@@ te schul@@ projekt bildungs@@ zentrum hoch@@ schwar@@ zw@@ ald vor@@ stellte . 153 | Difference of length: 0 154 | 155 | 156 | Prediction: es bestand die möglichkeit , dass sie sehr schwer verletzt war oder sch@@ lim@@ mer@@ es . 157 | Target: es bestand die möglichkeit , dass sie sehr schwer verletzt war oder sch@@ lim@@ mer@@ es . 158 | Difference of length: 0 159 | 160 | 161 | Prediction: zu dem unfall war es nach angaben der polizei gekommen , als ein 26 jahre alter mann am donn@@ erst@@ ag@@ abend , gegen 22 uhr , mit einem dam@@ en@@ fahr@@ rad ordnungs@@ widri@@ g auf dem linken geh@@ weg vom bahn@@ hof@@ platz in richtung markt@@ stätte unterwegs war . 162 | Target: zu dem unfall war es nach angaben der polizei gekommen , als ein 26 jahre alter mann am donn@@ erst@@ ag@@ abend , gegen 22 uhr , mit einem dam@@ en@@ fahr@@ rad ordnungs@@ widri@@ g auf dem linken geh@@ weg vom bahn@@ hof@@ platz in richtung markt@@ stätte unterwegs war . 163 | Difference of length: 0 164 | 165 | 166 | Prediction: " dieser aufwand ist nun weg " , freut sich ma@@ ier . 167 | Target: " dieser aufwand ist nun weg " , freut sich ma@@ ier . 168 | Difference of length: 0 169 | 170 | 171 | Prediction: thomas op@@ per@@ mann , der abgeordnete , der den für den geheim@@ dienst zuständigen parlamentarischen ausschuss leitet , erklärte , man solle die gelegenheit ergreifen , snow@@ den als zeu@@ ge anzu@@ hören , wenn dies möglich sei , „ ohne ihn zu gefährden und die beziehungen zu den usa völlig zu ru@@ in@@ ieren “ . 172 | Target: thomas op@@ per@@ mann , der abgeordnete , der den für den geheim@@ dienst zuständigen parlamentarischen ausschuss leitet , erklärte , man solle die gelegenheit ergreifen , snow@@ den als zeu@@ ge anzu@@ hören , wenn dies möglich sei , „ ohne ihn zu gefährden und die beziehungen zu den usa völlig zu ru@@ in@@ ieren “ . 173 | Difference of length: 0 174 | 175 | 176 | Prediction: r@@ eine pflanzen@@ mar@@ gar@@ ine sei eine gute alternative zu but@@ ter , jo@@ gh@@ urt lasse durch so@@ ja@@ jo@@ gh@@ urt ersetzen . 
177 | Target: r@@ eine pflanzen@@ mar@@ gar@@ ine sei eine gute alternative zu but@@ ter , jo@@ gh@@ urt lasse durch so@@ ja@@ jo@@ gh@@ urt ersetzen . 178 | Difference of length: 0 179 | 180 | 181 | Prediction: der handel am nas@@ da@@ q op@@ tions market wurde am frei@@ tag@@ nach@@ mittag deutscher zeit unterbrochen . 182 | Target: der handel am nas@@ da@@ q op@@ tions market wurde am frei@@ tag@@ nach@@ mittag deutscher zeit unterbrochen . 183 | Difference of length: 0 184 | 185 | 186 | Prediction: ö@@ z@@ dem@@ ir will j@@ azz@@ ausbildung in stuttgart erhalten 187 | Target: ö@@ z@@ dem@@ ir will j@@ azz@@ ausbildung in stuttgart erhalten 188 | Difference of length: 0 189 | ``` 190 | 191 | I'm still not sure why the validation accuracy is this good; I will need to drill down and debug the model further. One likely culprit is sketched in the note below.
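
### Why the validation predictions look perfect

The validation output above most likely reflects target leakage rather than a well-trained model. The decoder is fed the full ground-truth target (`self(batch)` calls `self.model(batch.src, batch.trg)`), the loss is computed against that same, un-shifted target, and neither `nn.MultiheadAttention` call in the decoder layers receives an `attn_mask`; the `generate_square_subsequent_mask` helper in `src/models/transformer.py` is defined but never used. Every output position can therefore see the exact token it is asked to predict and simply copy it.

Below is a minimal sketch of how the causal mask could be applied, assuming the decoder layers are extended to build and pass an `attn_mask`. The class name `MaskedSelfAttention` is only for illustration and does not exist in this repo.

```python
import torch
import torch.nn as nn


def square_subsequent_mask(sz):
    # Same construction as Transformer.generate_square_subsequent_mask:
    # a position may attend only to itself and to earlier positions.
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    return (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, 0.0)
    )


class MaskedSelfAttention(nn.Module):
    """Decoder self-attention with a causal mask (illustrative sketch)."""

    def __init__(self, dim, num_head):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_head)

    def forward(self, tgt):
        # tgt: (T, B, E), the default layout expected by nn.MultiheadAttention
        mask = square_subsequent_mask(tgt.size(0)).to(tgt.device)
        return self.attn(tgt, tgt, tgt, attn_mask=mask)[0]
```

On top of the mask, the decoder input and the loss target should be shifted against each other (for example, feed `batch.trg[:-1]` to the decoder and score the logits against `batch.trg[1:]`); otherwise position `t` is still trained to reproduce its own input token.
--------------------------------------------------------------------------------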