├── .gitignore ├── README.md ├── config.py ├── data.py ├── main.py ├── models ├── __init__.py ├── block │ ├── decoder_block.py │ └── encoder_block.py ├── build_model.py ├── embedding │ ├── positional_encoding.py │ ├── token_embedding.py │ └── transformer_embedding.py ├── layer │ ├── multi_head_attention_layer.py │ ├── position_wise_feed_forward_layer.py │ └── residual_connection_layer.py └── model │ ├── decoder.py │ ├── encoder.py │ └── transformer.py ├── prepare.sh ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .data 3 | *.pt 4 | *.out 5 | checkpoint 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # transformer_pytorch 2 | Transformer (Attention Is All You Need) implementation in PyTorch. 3 | 4 | A detailed implementation description is provided in the post below. 5 | #### [Blog Post](https://cpm0722.github.io/pytorch-implementation/transformer) 6 | 7 | --- 8 | 9 | #### Install 10 | ```bash 11 | bash prepare.sh 12 | ``` 13 | 14 | #### Run Train ([Multi30k](https://github.com/multi30k/dataset)) 15 | ```bash 16 | python3 main.py 17 | ``` 18 | 19 | --- 20 | 21 | #### Reference 22 | 23 | - ##### [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) 24 | - ##### [Harvard NLP](http://nlp.seas.harvard.edu/2018/04/03/attention.html) 25 | - ##### [hyunwoongko/transformer](https://github.com/hyunwoongko/transformer) 26 | - ##### [WikiDocs](https://wikidocs.net/31379) 27 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import torch 9 | 10 | DEVICE = torch.device('cuda:0') 11 | CHECKPOINT_DIR = "./checkpoint" 12 | 13 | N_EPOCH = 1000 14 | 15 | BATCH_SIZE = 2048 16 | NUM_WORKERS = 8 17 | 18 | LEARNING_RATE = 1e-5 19 | WEIGHT_DECAY = 5e-4 20 | ADAM_EPS = 5e-9 21 | SCHEDULER_FACTOR = 0.9 22 | SCHEDULER_PATIENCE = 10 23 | 24 | WARM_UP_STEP = 100 25 | 26 | DROPOUT_RATE = 0.1 27 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import os 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from torchtext.vocab import build_vocab_from_iterator 12 | import torchtext.transforms as T 13 | 14 | from utils import save_pkl, load_pkl 15 | 16 | 17 | class Multi30k(): 18 | 19 | def __init__(self, 20 | lang=("en", "de"), 21 | max_seq_len=256, 22 | unk_idx=0, 23 | pad_idx=1, 24 | sos_idx=2, 25 | eos_idx=3, 26 | vocab_min_freq=2): 27 | 28 | self.dataset_name = "multi30k" 29 | self.lang_src, self.lang_tgt = lang 30 | self.max_seq_len = max_seq_len 31 | self.unk_idx = unk_idx 32 | self.pad_idx = pad_idx 33 | self.sos_idx = sos_idx 34 | self.eos_idx = eos_idx 35 | self.unk = "<unk>" 36 | self.pad = "<pad>" 37 | self.sos = "<sos>" 38 | self.eos = "<eos>" 39 | self.specials = { 40 | self.unk: self.unk_idx, 41 | self.pad: self.pad_idx, 42 | self.sos: self.sos_idx, 43 | self.eos: self.eos_idx 44 | } 45 |
self.vocab_min_freq = vocab_min_freq 46 | 47 | self.tokenizer_src = self.build_tokenizer(self.lang_src) 48 | self.tokenizer_tgt = self.build_tokenizer(self.lang_tgt) 49 | 50 | self.train = None 51 | self.valid = None 52 | self.test = None 53 | self.build_dataset() 54 | 55 | self.vocab_src = None 56 | self.vocab_tgt = None 57 | self.build_vocab() 58 | 59 | self.transform_src = None 60 | self.transform_tgt = None 61 | self.build_transform() 62 | 63 | 64 | def build_dataset(self, raw_dir="raw", cache_dir=".data"): 65 | cache_dir = os.path.join(cache_dir, self.dataset_name) 66 | raw_dir = os.path.join(cache_dir, raw_dir) 67 | os.makedirs(raw_dir, exist_ok=True) 68 | 69 | train_file = os.path.join(cache_dir, "train.pkl") 70 | valid_file = os.path.join(cache_dir, "valid.pkl") 71 | test_file = os.path.join(cache_dir, "test.pkl") 72 | 73 | if os.path.exists(train_file): 74 | self.train = load_pkl(train_file) 75 | else: 76 | with open(os.path.join(raw_dir, "train.en"), "r") as f: 77 | train_en = [text.rstrip() for text in f] 78 | with open(os.path.join(raw_dir, "train.de"), "r") as f: 79 | train_de = [text.rstrip() for text in f] 80 | self.train = [(en, de) for en, de in zip(train_en, train_de)] 81 | save_pkl(self.train , train_file) 82 | 83 | if os.path.exists(valid_file): 84 | self.valid = load_pkl(valid_file) 85 | else: 86 | with open(os.path.join(raw_dir, "val.en"), "r") as f: 87 | valid_en = [text.rstrip() for text in f] 88 | with open(os.path.join(raw_dir, "val.de"), "r") as f: 89 | valid_de = [text.rstrip() for text in f] 90 | self.valid = [(en, de) for en, de in zip(valid_en, valid_de)] 91 | save_pkl(self.valid, valid_file) 92 | 93 | if os.path.exists(test_file): 94 | self.test = load_pkl(test_file) 95 | else: 96 | with open(os.path.join(raw_dir, "test_2016_flickr.en"), "r") as f: 97 | test_en = [text.rstrip() for text in f] 98 | with open(os.path.join(raw_dir, "test_2016_flickr.de"), "r") as f: 99 | test_de = [text.rstrip() for text in f] 100 | self.test = [(en, de) for en, de in zip(test_en, test_de)] 101 | save_pkl(self.test, test_file) 102 | 103 | 104 | def build_vocab(self, cache_dir=".data"): 105 | assert self.train is not None 106 | def yield_tokens(is_src=True): 107 | for text_pair in self.train: 108 | if is_src: 109 | yield [str(token) for token in self.tokenizer_src(text_pair[0])] 110 | else: 111 | yield [str(token) for token in self.tokenizer_tgt(text_pair[1])] 112 | 113 | cache_dir = os.path.join(cache_dir, self.dataset_name) 114 | os.makedirs(cache_dir, exist_ok=True) 115 | 116 | vocab_src_file = os.path.join(cache_dir, f"vocab_{self.lang_src}.pkl") 117 | if os.path.exists(vocab_src_file): 118 | vocab_src = load_pkl(vocab_src_file) 119 | else: 120 | vocab_src = build_vocab_from_iterator(yield_tokens(is_src=True), min_freq=self.vocab_min_freq, specials=self.specials.keys()) 121 | vocab_src.set_default_index(self.unk_idx) 122 | save_pkl(vocab_src, vocab_src_file) 123 | 124 | vocab_tgt_file = os.path.join(cache_dir, f"vocab_{self.lang_tgt}.pkl") 125 | if os.path.exists(vocab_tgt_file): 126 | vocab_tgt = load_pkl(vocab_tgt_file) 127 | else: 128 | vocab_tgt = build_vocab_from_iterator(yield_tokens(is_src=False), min_freq=self.vocab_min_freq, specials=self.specials.keys()) 129 | vocab_tgt.set_default_index(self.unk_idx) 130 | save_pkl(vocab_tgt, vocab_tgt_file) 131 | 132 | self.vocab_src = vocab_src 133 | self.vocab_tgt = vocab_tgt 134 | 135 | 136 | def build_tokenizer(self, lang): 137 | from torchtext.data.utils import get_tokenizer 138 | spacy_lang_dict = { 139 | 'en': 
"en_core_web_sm", 140 | 'de': "de_core_news_sm" 141 | } 142 | assert lang in spacy_lang_dict.keys() 143 | return get_tokenizer("spacy", spacy_lang_dict[lang]) 144 | 145 | 146 | def build_transform(self): 147 | def get_transform(self, vocab): 148 | return T.Sequential( 149 | T.VocabTransform(vocab), 150 | T.Truncate(self.max_seq_len-2), 151 | T.AddToken(token=self.sos_idx, begin=True), 152 | T.AddToken(token=self.eos_idx, begin=False), 153 | T.ToTensor(padding_value=self.pad_idx)) 154 | 155 | self.transform_src = get_transform(self, self.vocab_src) 156 | self.transform_tgt = get_transform(self, self.vocab_tgt) 157 | 158 | 159 | def collate_fn(self, pairs): 160 | src = [self.tokenizer_src(pair[0]) for pair in pairs] 161 | tgt = [self.tokenizer_tgt(pair[1]) for pair in pairs] 162 | batch_src = self.transform_src(src) 163 | batch_tgt = self.transform_tgt(tgt) 164 | return (batch_src, batch_tgt) 165 | 166 | 167 | def get_iter(self, **kwargs): 168 | if self.transform_src is None: 169 | self.build_transform() 170 | train_iter = DataLoader(self.train, collate_fn=self.collate_fn, **kwargs) 171 | valid_iter = DataLoader(self.valid, collate_fn=self.collate_fn, **kwargs) 172 | test_iter = DataLoader(self.test, collate_fn=self.collate_fn, **kwargs) 173 | return train_iter, valid_iter, test_iter 174 | 175 | 176 | def translate(self, model, src_sentence: str, decode_func): 177 | model.eval() 178 | src = self.transform_src([self.tokenizer_src(src_sentence)]).view(1, -1) 179 | num_tokens = src.shape[1] 180 | tgt_tokens = decode_func(model, 181 | src, 182 | max_len=num_tokens+5, 183 | start_symbol=self.sos_idx, 184 | end_symbol=self.eos_idx).flatten().cpu().numpy() 185 | tgt_sentence = " ".join(self.vocab_tgt.lookup_tokens(tgt_tokens)) 186 | return tgt_sentence 187 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import os, sys, time 9 | import logging 10 | 11 | import torch 12 | from torch import nn, optim 13 | 14 | from config import * 15 | from models.build_model import build_model 16 | from data import Multi30k 17 | from utils import get_bleu_score, greedy_decode 18 | 19 | 20 | DATASET = Multi30k() 21 | 22 | 23 | def train(model, data_loader, optimizer, criterion, epoch, checkpoint_dir): 24 | model.train() 25 | epoch_loss = 0 26 | 27 | for idx, (src, tgt) in enumerate(data_loader): 28 | src = src.to(model.device) 29 | tgt = tgt.to(model.device) 30 | tgt_x = tgt[:, :-1] 31 | tgt_y = tgt[:, 1:] 32 | 33 | optimizer.zero_grad() 34 | 35 | output, _ = model(src, tgt_x) 36 | 37 | y_hat = output.contiguous().view(-1, output.shape[-1]) 38 | y_gt = tgt_y.contiguous().view(-1) 39 | loss = criterion(y_hat, y_gt) 40 | loss.backward() 41 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 42 | optimizer.step() 43 | 44 | epoch_loss += loss.item() 45 | num_samples = idx + 1 46 | 47 | if checkpoint_dir: 48 | os.makedirs(checkpoint_dir, exist_ok=True) 49 | checkpoint_file = os.path.join(checkpoint_dir, f"{epoch:04d}.pt") 50 | torch.save({ 51 | 'epoch': epoch, 52 | 'model_state_dict': model.state_dict(), 53 | 'optimizer_state_dict': optimizer.state_dict(), 54 | 'loss': loss 55 | }, checkpoint_file) 56 | 57 | return epoch_loss / num_samples 58 | 59 | 60 | def evaluate(model, data_loader, criterion): 61 | model.eval() 62 | epoch_loss = 0 63 
| 64 | total_bleu = [] 65 | with torch.no_grad(): 66 | for idx, (src, tgt) in enumerate(data_loader): 67 | src = src.to(model.device) 68 | tgt = tgt.to(model.device) 69 | tgt_x = tgt[:, :-1] 70 | tgt_y = tgt[:, 1:] 71 | 72 | output, _ = model(src, tgt_x) 73 | 74 | y_hat = output.contiguous().view(-1, output.shape[-1]) 75 | y_gt = tgt_y.contiguous().view(-1) 76 | loss = criterion(y_hat, y_gt) 77 | 78 | epoch_loss += loss.item() 79 | score = get_bleu_score(output, tgt_y, DATASET.vocab_tgt, DATASET.specials) 80 | total_bleu.append(score) 81 | num_samples = idx + 1 82 | 83 | loss_avr = epoch_loss / num_samples 84 | bleu_score = sum(total_bleu) / len(total_bleu) 85 | return loss_avr, bleu_score 86 | 87 | 88 | def main(): 89 | model = build_model(len(DATASET.vocab_src), len(DATASET.vocab_tgt), device=DEVICE, dr_rate=DROPOUT_RATE) 90 | 91 | def initialize_weights(model): 92 | if hasattr(model, 'weight') and model.weight.dim() > 1: 93 | nn.init.kaiming_uniform_(model.weight.data) 94 | 95 | model.apply(initialize_weights) 96 | 97 | optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, eps=ADAM_EPS) 98 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, verbose=True, factor=SCHEDULER_FACTOR, patience=SCHEDULER_PATIENCE) 99 | 100 | criterion = nn.CrossEntropyLoss(ignore_index=DATASET.pad_idx) 101 | 102 | train_iter, valid_iter, test_iter = DATASET.get_iter(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) 103 | 104 | for epoch in range(N_EPOCH): 105 | logging.info(f"*****epoch: {epoch:02}*****") 106 | train_loss = train(model, train_iter, optimizer, criterion, epoch, CHECKPOINT_DIR) 107 | logging.info(f"train_loss: {train_loss:.5f}") 108 | valid_loss, bleu_score = evaluate(model, valid_iter, criterion) 109 | if epoch > WARM_UP_STEP: 110 | scheduler.step(valid_loss) 111 | logging.info(f"valid_loss: {valid_loss:.5f}, bleu_score: {bleu_score:.5f}") 112 | 113 | logging.info(DATASET.translate(model, "A little girl climbing into a wooden playhouse .", greedy_decode)) 114 | # expected output: "Ein kleines Mädchen klettert in ein Spielhaus aus Holz ." 
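# The sentence pair above comes from the Multi30k data; logging a greedy translation of it
# once per epoch gives a quick qualitative sanity check alongside the validation BLEU score.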
115 | 116 | test_loss, bleu_score = evaluate(model, test_iter, criterion) 117 | logging.info(f"test_loss: {test_loss:.5f}, bleu_score: {bleu_score:.5f}") 118 | 119 | 120 | if __name__ == "__main__": 121 | torch.manual_seed(0) 122 | logging.basicConfig(level=logging.INFO) 123 | main() 124 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | -------------------------------------------------------------------------------- /models/block/decoder_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import copy 9 | import torch.nn as nn 10 | 11 | from models.layer.residual_connection_layer import ResidualConnectionLayer 12 | 13 | 14 | class DecoderBlock(nn.Module): 15 | 16 | def __init__(self, self_attention, cross_attention, position_ff, norm, dr_rate=0): 17 | super(DecoderBlock, self).__init__() 18 | self.self_attention = self_attention 19 | self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate) 20 | self.cross_attention = cross_attention 21 | self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate) 22 | self.position_ff = position_ff 23 | self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate) 24 | 25 | 26 | def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask): 27 | out = tgt 28 | out = self.residual1(out, lambda out: self.self_attention(query=out, key=out, value=out, mask=tgt_mask)) 29 | out = self.residual2(out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out, mask=src_tgt_mask)) 30 | out = self.residual3(out, self.position_ff) 31 | return out 32 | -------------------------------------------------------------------------------- /models/block/encoder_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import copy 9 | import torch.nn as nn 10 | 11 | from models.layer.residual_connection_layer import ResidualConnectionLayer 12 | 13 | 14 | class EncoderBlock(nn.Module): 15 | 16 | def __init__(self, self_attention, position_ff, norm, dr_rate=0): 17 | super(EncoderBlock, self).__init__() 18 | self.self_attention = self_attention 19 | self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate) 20 | self.position_ff = position_ff 21 | self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate) 22 | 23 | 24 | def forward(self, src, src_mask): 25 | out = src 26 | out = self.residual1(out, lambda out: self.self_attention(query=out, key=out, value=out, mask=src_mask)) 27 | out = self.residual2(out, self.position_ff) 28 | return out 29 | -------------------------------------------------------------------------------- /models/build_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import torch 9 | import torch.nn as 
nn 10 | 11 | from models.model.transformer import Transformer 12 | from models.model.encoder import Encoder 13 | from models.model.decoder import Decoder 14 | from models.block.encoder_block import EncoderBlock 15 | from models.block.decoder_block import DecoderBlock 16 | from models.layer.multi_head_attention_layer import MultiHeadAttentionLayer 17 | from models.layer.position_wise_feed_forward_layer import PositionWiseFeedForwardLayer 18 | from models.embedding.transformer_embedding import TransformerEmbedding 19 | from models.embedding.token_embedding import TokenEmbedding 20 | from models.embedding.positional_encoding import PositionalEncoding 21 | 22 | 23 | def build_model(src_vocab_size, 24 | tgt_vocab_size, 25 | device=torch.device("cpu"), 26 | max_len = 256, 27 | d_embed = 512, 28 | n_layer = 6, 29 | d_model = 512, 30 | h = 8, 31 | d_ff = 2048, 32 | dr_rate = 0.1, 33 | norm_eps = 1e-5): 34 | import copy 35 | copy = copy.deepcopy 36 | 37 | src_token_embed = TokenEmbedding( 38 | d_embed = d_embed, 39 | vocab_size = src_vocab_size) 40 | tgt_token_embed = TokenEmbedding( 41 | d_embed = d_embed, 42 | vocab_size = tgt_vocab_size) 43 | pos_embed = PositionalEncoding( 44 | d_embed = d_embed, 45 | max_len = max_len, 46 | device = device) 47 | 48 | src_embed = TransformerEmbedding( 49 | token_embed = src_token_embed, 50 | pos_embed = copy(pos_embed), 51 | dr_rate = dr_rate) 52 | tgt_embed = TransformerEmbedding( 53 | token_embed = tgt_token_embed, 54 | pos_embed = copy(pos_embed), 55 | dr_rate = dr_rate) 56 | 57 | attention = MultiHeadAttentionLayer( 58 | d_model = d_model, 59 | h = h, 60 | qkv_fc = nn.Linear(d_embed, d_model), 61 | out_fc = nn.Linear(d_model, d_embed), 62 | dr_rate = dr_rate) 63 | position_ff = PositionWiseFeedForwardLayer( 64 | fc1 = nn.Linear(d_embed, d_ff), 65 | fc2 = nn.Linear(d_ff, d_embed), 66 | dr_rate = dr_rate) 67 | norm = nn.LayerNorm(d_embed, eps = norm_eps) 68 | 69 | encoder_block = EncoderBlock( 70 | self_attention = copy(attention), 71 | position_ff = copy(position_ff), 72 | norm = copy(norm), 73 | dr_rate = dr_rate) 74 | decoder_block = DecoderBlock( 75 | self_attention = copy(attention), 76 | cross_attention = copy(attention), 77 | position_ff = copy(position_ff), 78 | norm = copy(norm), 79 | dr_rate = dr_rate) 80 | 81 | encoder = Encoder( 82 | encoder_block = encoder_block, 83 | n_layer = n_layer, 84 | norm = copy(norm)) 85 | decoder = Decoder( 86 | decoder_block = decoder_block, 87 | n_layer = n_layer, 88 | norm = copy(norm)) 89 | generator = nn.Linear(d_model, tgt_vocab_size) 90 | 91 | model = Transformer( 92 | src_embed = src_embed, 93 | tgt_embed = tgt_embed, 94 | encoder = encoder, 95 | decoder = decoder, 96 | generator = generator).to(device) 97 | model.device = device 98 | 99 | return model 100 | -------------------------------------------------------------------------------- /models/embedding/positional_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class PositionalEncoding(nn.Module): 14 | 15 | def __init__(self, d_embed, max_len=256, device=torch.device("cpu")): 16 | super(PositionalEncoding, self).__init__() 17 | encoding = torch.zeros(max_len, d_embed) 18 | encoding.requires_grad = False 19 | position = torch.arange(0, max_len).float().unsqueeze(1) 20 | div_term = 
torch.exp(torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed)) 21 | encoding[:, 0::2] = torch.sin(position * div_term) 22 | encoding[:, 1::2] = torch.cos(position * div_term) 23 | self.encoding = encoding.unsqueeze(0).to(device) 24 | 25 | 26 | def forward(self, x): 27 | _, seq_len, _ = x.size() 28 | pos_embed = self.encoding[:, :seq_len, :] 29 | out = x + pos_embed 30 | return out 31 | -------------------------------------------------------------------------------- /models/embedding/token_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import math 9 | import torch.nn as nn 10 | 11 | 12 | class TokenEmbedding(nn.Module): 13 | 14 | def __init__(self, d_embed, vocab_size): 15 | super(TokenEmbedding, self).__init__() 16 | self.embedding = nn.Embedding(vocab_size, d_embed) 17 | self.d_embed = d_embed 18 | 19 | 20 | def forward(self, x): 21 | out = self.embedding(x) * math.sqrt(self.d_embed) 22 | return out 23 | -------------------------------------------------------------------------------- /models/embedding/transformer_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class TransformerEmbedding(nn.Module): 12 | 13 | def __init__(self, token_embed, pos_embed, dr_rate=0): 14 | super(TransformerEmbedding, self).__init__() 15 | self.embedding = nn.Sequential(token_embed, pos_embed) 16 | self.dropout = nn.Dropout(p=dr_rate) 17 | 18 | 19 | def forward(self, x): 20 | out = x 21 | out = self.embedding(out) 22 | out = self.dropout(out) 23 | return out 24 | -------------------------------------------------------------------------------- /models/layer/multi_head_attention_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import copy 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class MultiHeadAttentionLayer(nn.Module): 16 | 17 | def __init__(self, d_model, h, qkv_fc, out_fc, dr_rate=0): 18 | super(MultiHeadAttentionLayer, self).__init__() 19 | self.d_model = d_model 20 | self.h = h 21 | self.q_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model) 22 | self.k_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model) 23 | self.v_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model) 24 | self.out_fc = out_fc # (d_model, d_embed) 25 | self.dropout = nn.Dropout(p=dr_rate) 26 | 27 | 28 | def calculate_attention(self, query, key, value, mask): 29 | # query, key, value: (n_batch, h, seq_len, d_k) 30 | # mask: (n_batch, seq_len, seq_len) 31 | d_k = key.shape[-1] 32 | attention_score = torch.matmul(query, key.transpose(-2, -1)) # Q x K^T, (n_batch, h, seq_len, seq_len) 33 | attention_score = attention_score / math.sqrt(d_k) 34 | if mask is not None: 35 | attention_score = attention_score.masked_fill(mask==0, -1e9) 36 | attention_prob = F.softmax(attention_score, dim=-1) # (n_batch, h, seq_len, seq_len) 37 | attention_prob = self.dropout(attention_prob) 38 | out = torch.matmul(attention_prob, value) # (n_batch, h, 
seq_len, d_k) 39 | return out 40 | 41 | 42 | def forward(self, *args, query, key, value, mask=None): 43 | # query, key, value: (n_batch, seq_len, d_embed) 44 | # mask: (n_batch, seq_len, seq_len) 45 | # return value: (n_batch, h, seq_len, d_k) 46 | n_batch = query.size(0) 47 | 48 | def transform(x, fc): # (n_batch, seq_len, d_embed) 49 | out = fc(x) # (n_batch, seq_len, d_model) 50 | out = out.view(n_batch, -1, self.h, self.d_model//self.h) # (n_batch, seq_len, h, d_k) 51 | out = out.transpose(1, 2) # (n_batch, h, seq_len, d_k) 52 | return out 53 | 54 | query = transform(query, self.q_fc) # (n_batch, h, seq_len, d_k) 55 | key = transform(key, self.k_fc) # (n_batch, h, seq_len, d_k) 56 | value = transform(value, self.v_fc) # (n_batch, h, seq_len, d_k) 57 | 58 | out = self.calculate_attention(query, key, value, mask) # (n_batch, h, seq_len, d_k) 59 | out = out.transpose(1, 2) # (n_batch, seq_len, h, d_k) 60 | out = out.contiguous().view(n_batch, -1, self.d_model) # (n_batch, seq_len, d_model) 61 | out = self.out_fc(out) # (n_batch, seq_len, d_embed) 62 | return out 63 | -------------------------------------------------------------------------------- /models/layer/position_wise_feed_forward_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class PositionWiseFeedForwardLayer(nn.Module): 12 | 13 | def __init__(self, fc1, fc2, dr_rate=0): 14 | super(PositionWiseFeedForwardLayer, self).__init__() 15 | self.fc1 = fc1 # (d_embed, d_ff) 16 | self.relu = nn.ReLU() 17 | self.dropout = nn.Dropout(p=dr_rate) 18 | self.fc2 = fc2 # (d_ff, d_embed) 19 | 20 | 21 | def forward(self, x): 22 | out = x 23 | out = self.fc1(out) 24 | out = self.relu(out) 25 | out = self.dropout(out) 26 | out = self.fc2(out) 27 | return out 28 | -------------------------------------------------------------------------------- /models/layer/residual_connection_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class ResidualConnectionLayer(nn.Module): 12 | 13 | def __init__(self, norm, dr_rate=0): 14 | super(ResidualConnectionLayer, self).__init__() 15 | self.norm = norm 16 | self.dropout = nn.Dropout(p=dr_rate) 17 | 18 | 19 | def forward(self, x, sub_layer): 20 | out = x 21 | out = self.norm(out) 22 | out = sub_layer(out) 23 | out = self.dropout(out) 24 | out = out + x 25 | return out 26 | -------------------------------------------------------------------------------- /models/model/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import copy 9 | import torch.nn as nn 10 | 11 | 12 | class Decoder(nn.Module): 13 | 14 | def __init__(self, decoder_block, n_layer, norm): 15 | super(Decoder, self).__init__() 16 | self.n_layer = n_layer 17 | self.layers = nn.ModuleList([copy.deepcopy(decoder_block) for _ in range(self.n_layer)]) 18 | self.norm = norm 19 | 20 | 21 | def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask): 22 | out = tgt 23 | for layer in 
self.layers: 24 | out = layer(out, encoder_out, tgt_mask, src_tgt_mask) 25 | out = self.norm(out) 26 | return out 27 | -------------------------------------------------------------------------------- /models/model/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import copy 9 | import torch.nn as nn 10 | 11 | 12 | class Encoder(nn.Module): 13 | 14 | def __init__(self, encoder_block, n_layer, norm): 15 | super(Encoder, self).__init__() 16 | self.n_layer = n_layer 17 | self.layers = nn.ModuleList([copy.deepcopy(encoder_block) for _ in range(self.n_layer)]) 18 | self.norm = norm 19 | 20 | 21 | def forward(self, src, src_mask): 22 | out = src 23 | for layer in self.layers: 24 | out = layer(out, src_mask) 25 | out = self.norm(out) 26 | return out 27 | -------------------------------------------------------------------------------- /models/model/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class Transformer(nn.Module): 15 | 16 | def __init__(self, src_embed, tgt_embed, encoder, decoder, generator): 17 | super(Transformer, self).__init__() 18 | self.src_embed = src_embed 19 | self.tgt_embed = tgt_embed 20 | self.encoder = encoder 21 | self.decoder = decoder 22 | self.generator = generator 23 | 24 | 25 | def encode(self, src, src_mask): 26 | return self.encoder(self.src_embed(src), src_mask) 27 | 28 | 29 | def decode(self, tgt, encoder_out, tgt_mask, src_tgt_mask): 30 | return self.decoder(self.tgt_embed(tgt), encoder_out, tgt_mask, src_tgt_mask) 31 | 32 | 33 | def forward(self, src, tgt): 34 | src_mask = self.make_src_mask(src) 35 | tgt_mask = self.make_tgt_mask(tgt) 36 | src_tgt_mask = self.make_src_tgt_mask(src, tgt) 37 | encoder_out = self.encode(src, src_mask) 38 | decoder_out = self.decode(tgt, encoder_out, tgt_mask, src_tgt_mask) 39 | out = self.generator(decoder_out) 40 | out = F.log_softmax(out, dim=-1) 41 | return out, decoder_out 42 | 43 | 44 | def make_src_mask(self, src): 45 | pad_mask = self.make_pad_mask(src, src) 46 | return pad_mask 47 | 48 | 49 | def make_tgt_mask(self, tgt): 50 | pad_mask = self.make_pad_mask(tgt, tgt) 51 | seq_mask = self.make_subsequent_mask(tgt, tgt) 52 | mask = pad_mask & seq_mask 53 | return pad_mask & seq_mask 54 | 55 | 56 | def make_src_tgt_mask(self, src, tgt): 57 | pad_mask = self.make_pad_mask(tgt, src) 58 | return pad_mask 59 | 60 | 61 | def make_pad_mask(self, query, key, pad_idx=1): 62 | # query: (n_batch, query_seq_len) 63 | # key: (n_batch, key_seq_len) 64 | query_seq_len, key_seq_len = query.size(1), key.size(1) 65 | 66 | key_mask = key.ne(pad_idx).unsqueeze(1).unsqueeze(2) # (n_batch, 1, 1, key_seq_len) 67 | key_mask = key_mask.repeat(1, 1, query_seq_len, 1) # (n_batch, 1, query_seq_len, key_seq_len) 68 | 69 | query_mask = query.ne(pad_idx).unsqueeze(1).unsqueeze(3) # (n_batch, 1, query_seq_len, 1) 70 | query_mask = query_mask.repeat(1, 1, 1, key_seq_len) # (n_batch, 1, query_seq_len, key_seq_len) 71 | 72 | mask = key_mask & query_mask 73 | mask.requires_grad = False 74 | return mask 75 | 76 | 77 | def 
make_subsequent_mask(self, query, key): 78 | query_seq_len, key_seq_len = query.size(1), key.size(1) 79 | 80 | tril = np.tril(np.ones((query_seq_len, key_seq_len)), k=0).astype('uint8') # lower triangular mask, diagonal included 81 | mask = torch.tensor(tril, dtype=torch.bool, requires_grad=False, device=query.device) 82 | return mask 83 | -------------------------------------------------------------------------------- /prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install -r requirements.txt 4 | 5 | mkdir -p ./.data/multi30k/raw 6 | wget https://github.com/multi30k/dataset/raw/master/data/task1/raw/train.en.gz && mv train.en.gz ./.data/multi30k/raw && gzip -d ./.data/multi30k/raw/train.en.gz 7 | wget https://github.com/multi30k/dataset/raw/master/data/task1/raw/train.de.gz && mv train.de.gz ./.data/multi30k/raw && gzip -d ./.data/multi30k/raw/train.de.gz 8 | wget https://github.com/multi30k/dataset/raw/master/data/task1/raw/val.en.gz && mv val.en.gz ./.data/multi30k/raw && gzip -d ./.data/multi30k/raw/val.en.gz 9 | wget https://github.com/multi30k/dataset/raw/master/data/task1/raw/val.de.gz && mv val.de.gz ./.data/multi30k/raw && gzip -d ./.data/multi30k/raw/val.de.gz 10 | wget https://github.com/multi30k/dataset/raw/master/data/task1/raw/test_2016_flickr.en.gz && mv test_2016_flickr.en.gz ./.data/multi30k/raw && gzip -d ./.data/multi30k/raw/test_2016_flickr.en.gz 11 | wget https://github.com/multi30k/dataset/raw/master/data/task1/raw/test_2016_flickr.de.gz && mv test_2016_flickr.de.gz ./.data/multi30k/raw && gzip -d ./.data/multi30k/raw/test_2016_flickr.de.gz 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.2 2 | markupsafe==2.0.1 3 | transformers==4.21.1 4 | huggingface-hub==0.8.1 5 | datasets==2.4.0 6 | spacy==3.0.8 7 | https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm 8 | https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm 9 | 10 | --extra-index-url https://download.pytorch.org/whl/cu113 11 | torch==1.11.0 12 | torchtext==0.12.0 13 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : Hansu Kim(@cpm0722) 3 | @when : 2022-08-21 4 | @github : https://github.com/cpm0722 5 | @homepage : https://cpm0722.github.io 6 | """ 7 | 8 | import pickle 9 | import torch 10 | from torchtext.data.metrics import bleu_score 11 | 12 | 13 | def save_pkl(data, fname): 14 | with open(fname, "wb") as f: 15 | pickle.dump(data, f) 16 | 17 | 18 | def load_pkl(fname): 19 | with open(fname, "rb") as f: 20 | data = pickle.load(f) 21 | return data 22 | 23 | 24 | def get_bleu_score(output, gt, vocab, specials, max_n=4): 25 | 26 | def itos(x): 27 | x = list(x.cpu().numpy()) 28 | tokens = vocab.lookup_tokens(x) 29 | tokens = list(filter(lambda x: x not in {"", " ", "."} and x not in list(specials.keys()), tokens)) 30 | return tokens 31 | 32 | pred = [out.max(dim=1)[1] for out in output] 33 | pred_str = list(map(itos, pred)) 34 | gt_str = list(map(lambda x: [itos(x)], gt)) 35 | 36 | score = bleu_score(pred_str, gt_str, max_n=max_n) * 100 37 | return score 38 | 39 | 40 | def greedy_decode(model, src, max_len,
start_symbol, end_symbol): 41 | src = src.to(model.device) 42 | src_mask = model.make_src_mask(src).to(model.device) 43 | memory = model.encode(src, src_mask) 44 | 45 | ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(model.device) 46 | for i in range(max_len-1): 47 | memory = memory.to(model.device) 48 | tgt_mask = model.make_tgt_mask(ys).to(model.device) 49 | src_tgt_mask = model.make_src_tgt_mask(src, ys).to(model.device) 50 | out = model.decode(ys, memory, tgt_mask, src_tgt_mask) 51 | prob = model.generator(out[:, -1]) 52 | _, next_word = torch.max(prob, dim=1) 53 | next_word = next_word.item() 54 | 55 | ys = torch.cat([ys, 56 | torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1) 57 | if next_word == end_symbol: 58 | break 59 | return ys 60 | --------------------------------------------------------------------------------
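Supplementary note (not part of the repository): the sketch below re-implements the masking logic from `models/model/transformer.py` as standalone functions so the combined target mask can be inspected on a toy batch. The example token ids and the `pad`=1, `<sos>`=2, `<eos>`=3 indices mirror the Multi30k defaults above; everything else is illustrative only.

```python
import numpy as np
import torch

def make_pad_mask(query, key, pad_idx=1):
    # (n_batch, 1, query_seq_len, key_seq_len); True where neither position is <pad>
    query_seq_len, key_seq_len = query.size(1), key.size(1)
    key_mask = key.ne(pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, query_seq_len, 1)
    query_mask = query.ne(pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, key_seq_len)
    return key_mask & query_mask

def make_subsequent_mask(query, key):
    # lower-triangular with the diagonal kept: position i may attend to positions <= i
    tril = np.tril(np.ones((query.size(1), key.size(1))), k=0).astype('uint8')
    return torch.tensor(tril, dtype=torch.bool, device=query.device)

tgt = torch.tensor([[2, 7, 9, 3, 1]])  # <sos>, two word ids, <eos>, <pad>
tgt_mask = make_pad_mask(tgt, tgt) & make_subsequent_mask(tgt, tgt)
print(tgt_mask[0, 0].long())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0]])
```

Because `np.tril(..., k=0)` keeps the diagonal, every position can attend to itself; the padding mask then blanks out the row and column belonging to the `<pad>` token. This is exactly the mask that `make_tgt_mask` produces before it is consumed by `masked_fill` in `calculate_attention`.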
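And a minimal end-to-end smoke test, assuming only the modules in this repository. The tiny hyper-parameters and the random token ids are illustrative assumptions rather than a recommended configuration; the special indices (`pad`=1, `sos`=2, `eos`=3) match the `Multi30k` defaults.

```python
import torch
from models.build_model import build_model
from utils import greedy_decode

# Small dimensions so the check runs in seconds on CPU. d_embed is kept equal to
# d_model because the generator nn.Linear(d_model, tgt_vocab_size) is applied to the
# d_embed-sized decoder output (both are 512 in the defaults).
model = build_model(src_vocab_size=100, tgt_vocab_size=120,
                    device=torch.device("cpu"),
                    max_len=32, d_embed=32, n_layer=2,
                    d_model=32, h=4, d_ff=64, dr_rate=0.1)
model.eval()  # disable dropout so greedy decoding is deterministic

src = torch.randint(4, 100, (2, 10))  # (n_batch, src_seq_len); ids >= 4 avoid the special tokens
tgt = torch.randint(4, 120, (2, 12))  # (n_batch, tgt_seq_len)

with torch.no_grad():
    out, _ = model(src, tgt[:, :-1])  # teacher forcing: the decoder input drops the last target token
print(out.shape)  # torch.Size([2, 11, 120]) -- per-position log-probabilities

# Greedy decoding of the first source sentence, from <sos> (2) until <eos> (3) or max_len.
ys = greedy_decode(model, src[:1], max_len=20, start_symbol=2, end_symbol=3)
print(ys)  # (1, generated_len) tensor of target token ids, starting with 2
```

With untrained weights the generated ids are arbitrary, but the shapes and the mask plumbing can be verified this way before launching the full Multi30k run with `python3 main.py`.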