├── tools ├── __init__.py ├── const.py ├── tokenizer.py └── data_loader.py ├── images ├── transformer-decoder.png ├── transformer-architecture.png ├── transformer-residual-dropout.png ├── transformer-scaled-dot-product.png └── transformer-multi-head-attention.png ├── requirements.txt ├── transformer ├── debug.py ├── modules.py ├── embed.py ├── mask.py ├── optimizer.py ├── __init__.py ├── layers.py ├── sublayers.py ├── translator.py └── models.py ├── .gitignore ├── README.md ├── preprocess_chatbot.py ├── translate.py ├── preprocess.py └── train.py
/tools/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/tools/const.py: -------------------------------------------------------------------------------- 1 | PAD = '<pad>' 2 | UNK = '<unk>' 3 | SOS = '<sos>' 4 | EOS = '<eos>' 5 | --------------------------------------------------------------------------------
/images/transformer-decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndersonJo/transformer-anderson/master/images/transformer-decoder.png --------------------------------------------------------------------------------
/images/transformer-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndersonJo/transformer-anderson/master/images/transformer-architecture.png --------------------------------------------------------------------------------
/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | tqdm 4 | 5 | torch==1.4.0 6 | torchtext==0.5.0 7 | torchvision==0.5.0 8 | spacy==2.2.3 9 | nltk==3.4.5 10 | konlpy --------------------------------------------------------------------------------
/images/transformer-residual-dropout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndersonJo/transformer-anderson/master/images/transformer-residual-dropout.png --------------------------------------------------------------------------------
/images/transformer-scaled-dot-product.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndersonJo/transformer-anderson/master/images/transformer-scaled-dot-product.png --------------------------------------------------------------------------------
/images/transformer-multi-head-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndersonJo/transformer-anderson/master/images/transformer-multi-head-attention.png --------------------------------------------------------------------------------
/transformer/debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to_sentence(index_sentence, vocab): 5 | n = len(vocab.itos) 6 | if len(index_sentence.size()) == 1: 7 | return ' '.join([vocab.itos[w] for w in index_sentence]) 8 | elif index_sentence.size(1) == n: 9 | _, sentence = torch.softmax(index_sentence, dim=1).max(1) 10 | return ' '.join([vocab.itos[w] for w in sentence]) 11 | --------------------------------------------------------------------------------
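`to_sentence` above is a small debugging helper: given a 1-D tensor of vocabulary indices (or a 2-D tensor of per-position vocabulary logits) and a torchtext vocabulary, it reconstructs a readable sentence. A minimal usage sketch (not part of the repository; the toy vocabulary below is a hypothetical stand-in for a real torchtext `Vocab`):

```python
import torch
from types import SimpleNamespace

from transformer.debug import to_sentence

# A torchtext Vocab exposes an `itos` (index-to-string) list; this stand-in mimics it.
toy_vocab = SimpleNamespace(itos=['<unk>', '<pad>', '<sos>', '<eos>', 'a', 'dog', 'runs'])
print(to_sentence(torch.tensor([4, 5, 6]), toy_vocab))  # -> "a dog runs"
```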
/.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .idea 3 | _site 4 | 5 | 6 | # Python 7 | *.pyc 8 | .ipynb_checkpoints 9 | 10 | # Data 11 | _data 12 | *.gz 13 | *.zip 14 | *.alz 15 | *.data 16 | *.json 17 | *.npz 18 | *.pkl 19 | *.span 20 | *.context 21 | *.question 22 | *.answer 23 | cifar-10-batches-py 24 | *.h5 25 | joblib 26 | glove.840B.300d.txt 27 | *.log 28 | 29 | # Specific Data 30 | imdb_full.pkl 31 | 32 | # TensorFlow Model 33 | *.model 34 | *.tfmodel 35 | _network 36 | _network_backup 37 | _tfmodel 38 | checkpoints 39 | 40 | # R 41 | .Rhistory 42 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # Transformer 2 | 3 | ## Installation 4 | 5 | ```bash 6 | pip install -r requirements.txt 7 | python -m spacy download en 8 | python -m spacy download de 9 | ``` 10 | 11 | Then start a Python shell and install the NLTK WordNet data: 12 | 13 | ``` 14 | import nltk 15 | nltk.download('wordnet') 16 | ``` 17 | 18 | 19 | ## Preprocessing 20 | 21 | Download the training data and preprocess it: 22 | 23 | ```bash 24 | python preprocess.py 25 | ``` 26 | 27 | ## Training 28 | 29 | To train with the default settings, simply run: 30 | 31 | ```bash 32 | python train.py 33 | ``` 34 | 35 | To override hyper-parameters, for example: 36 | 37 | ``` 38 | python3.6 train.py --batch_size=64 39 | ``` 40 | 41 | When training in the cloud, nohup is useful: 42 | 43 | ```bash 44 | nohup python train.py & 45 | tail -f .train.log 46 | ``` --------------------------------------------------------------------------------
/tools/tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import spacy 4 | from konlpy.tag import Mecab 5 | 6 | 7 | class Tokenizer(object): 8 | def __init__(self, lang: str): 9 | self.nlp = spacy.load(lang) 10 | 11 | def tokenizer(self, sentence): 12 | sentence = re.sub(r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sentence)) 13 | sentence = re.sub(r"[ ]+", " ", sentence) 14 | sentence = re.sub(r"\!+", "!", sentence) 15 | sentence = re.sub(r"\,+", ",", sentence) 16 | sentence = re.sub(r"\?+", "?", sentence) 17 | sentence = sentence.lower() 18 | return [tok.text for tok in self.nlp.tokenizer(sentence) if tok.text != " "] 19 | 20 | 21 | class TokenizerKorea(object): 22 | def __init__(self): 23 | self.nlp = Mecab() 24 | 25 | def tokenizer(self, sentence): 26 | # Morpheme-level tokenization via Mecab, mirroring the spaCy tokenizer interface above 27 | return self.nlp.morphs(str(sentence)) 28 | --------------------------------------------------------------------------------
/transformer/modules.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class NormLayer(nn.Module): 7 | __constants__ = ['norm_shape', 'weight', 'bias', 'eps'] 8 | 9 | def __init__(self, norm_shape, eps=1e-6): 10 | super(NormLayer, self).__init__() 11 | if isinstance(norm_shape, numbers.Integral): 12 | norm_shape = (norm_shape,) 13 | self.norm_shape = norm_shape 14 | 15 | # create two trainable parameters to do affine tuning 16 | self.weight = nn.Parameter(torch.ones(*self.norm_shape), requires_grad=True) 17 | self.bias = nn.Parameter(torch.zeros(*self.norm_shape), requires_grad=True) 18 | self.eps = eps 19 | 20 | def forward(self, x): 21 | # weight * (x - mean) / (std + eps) + bias; the bias is an additive offset, not part of the denominator 22 | norm = self.weight * (x - x.mean(dim=-1, keepdim=True)) 23 | norm = norm / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias 24 | return norm 25 | --------------------------------------------------------------------------------
/transformer/embed.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 |
class PositionalEncoding(nn.Module): 8 | 9 | def __init__(self, embed_dim, max_seq_len=400): 10 | super(PositionalEncoding, self).__init__() 11 | 12 | pe = torch.zeros(max_seq_len, embed_dim) # (400, 512) shape 의 matrix 13 | position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) 14 | div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)) 15 | pe[:, 0::2] = torch.sin(position * div_term) 16 | pe[:, 1::2] = torch.cos(position * div_term) 17 | pe = pe.unsqueeze(0) # (1, 400, 512) shape 으로 만든다 18 | self.register_buffer('pe', pe) # 논문에서 positional emcodding은 constant matrix 임으로 register_buffer 사용 19 | 20 | def forward(self, x): 21 | return x + self.pe[:, :x.size(1)].detach() # constant matrix 이기 때문에 detach 시킨다 22 | -------------------------------------------------------------------------------- /tools/data_loader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Tuple 3 | 4 | from torchtext.vocab import Vocab 5 | 6 | from tools import const 7 | from torchtext.data import Dataset, BucketIterator 8 | 9 | 10 | def load_preprocessed_data(opt) -> Tuple[BucketIterator, BucketIterator, Vocab, Vocab]: 11 | batch_size = opt.batch_size 12 | device = opt.device 13 | data = pickle.load(open(opt.data_pkl, 'rb')) 14 | 15 | opt.max_seq_len = data['opt'].max_seq_len 16 | opt.src_pad_idx = data['src'].vocab.stoi[const.PAD] 17 | opt.trg_pad_idx = data['trg'].vocab.stoi[const.PAD] 18 | opt.src_vocab_size = len(data['src'].vocab) 19 | opt.trg_vocab_size = len(data['trg'].vocab) 20 | 21 | if opt.share_embed_weights: 22 | assert data['src'].vocab.stoi == data['trg'].vocab.stoi 23 | 24 | fields = {'src': data['src'], 'trg': data['trg']} 25 | train = Dataset(examples=data['train'], fields=fields) 26 | val = Dataset(examples=data['val'], fields=fields) 27 | 28 | train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True) 29 | val_iterator = BucketIterator(val, batch_size=batch_size, device=device) 30 | 31 | return train_iterator, val_iterator, data['src'].vocab, data['trg'].vocab 32 | -------------------------------------------------------------------------------- /transformer/mask.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def create_mask(src: torch.Tensor, 7 | trg: torch.Tensor, 8 | src_pad_idx: int, 9 | trg_pad_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 10 | src_mask = create_padding_mask(src, src_pad_idx) 11 | trg_mask = None 12 | if trg is not None: 13 | trg_mask = create_padding_mask(trg, trg_pad_idx) # (256, 1, 33) 14 | nopeak_mask = create_nopeak_mask(trg) # (1, 33, 33) 15 | trg_mask = trg_mask & nopeak_mask # (256, 33, 33) 16 | 17 | return src_mask, trg_mask 18 | 19 | 20 | def create_padding_mask(seq: torch.Tensor, pad_idx: int) -> torch.Tensor: 21 | """ 22 | seq 형태를 (256, 33) -> (256, 1, 31) 이렇게 변경합니다. 23 | 24 | 아래와 같이 padding index부분을 False로 변경합니다. (리턴 tensor) 25 | 아래의 vector 하나당 sentence라고 보면 되고, True로 되어 있는건 단어가 있다는 뜻. 
26 | tensor([[[ True, True, True, True, False, False, False]], 27 | [[ True, True, False, False, False, False, False]], 28 | [[ True, True, True, True, True, True, False]]]) 29 | """ 30 | return (seq != pad_idx).unsqueeze(-2) 31 | 32 | 33 | def create_nopeak_mask(seq: torch.Tensor) -> torch.Tensor: 34 | """ 35 | NO PEAK MASK 36 | For the target, mask out subsequent (future) words so they cannot be attended to 37 | """ 38 | batch_size, seq_len = seq.size() 39 | nopeak_mask = (1 - torch.triu(torch.ones(1, seq_len, seq_len, device=seq.device), diagonal=1)).bool() 40 | return nopeak_mask 41 | --------------------------------------------------------------------------------
/transformer/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | 4 | 5 | class ScheduledAdam(Adam): 6 | """ 7 | The paper uses the Adam optimizer with the following hyper-parameters: 8 | beta_1 = 0.9 9 | beta_2 = 0.98 10 | eps = 10**-9 11 | 12 | The learning rate is changed during training (see _update_learning_rate) 13 | """ 14 | 15 | def __init__(self, parameters, embed_dim: int, 16 | betas=(0.9, 0.98), eps=1e-09, init_lr: float = 2.0, 17 | warmup_steps: int = 4000, **kwargs): 18 | """ 19 | The important point about warm-up is that enough data must be seen during the warm-up steps. 20 | If batch_size is smaller than 2048 and warmup_steps is 4000 or less, only a small amount of data 21 | is covered during warm-up, so either increase the batch size or use more warm-up steps. 22 | 23 | :param parameters: transformer.parameters() <- Transformer Model weights 24 | :param embed_dim: it is used as "d_model" in paper and the default value is 512 25 | :param warmup_steps: warm-up steps. the lr will linearly increase during warm-up steps 26 | """ 27 | super().__init__(parameters, betas=betas, eps=eps, **kwargs) 28 | 29 | self.embed_dim = embed_dim 30 | self.init_lr = init_lr 31 | self.warmup_step = warmup_steps 32 | self.n_step = 0 33 | self.lr = 0 34 | 35 | def step(self, **kwargs): 36 | self._update_learning_rate() 37 | super().step() 38 | 39 | def _update_learning_rate(self): 40 | self.n_step += 1 41 | 42 | step, warmup_step = self.n_step, self.warmup_step 43 | init_lr = self.init_lr 44 | d_model = self.embed_dim 45 | 46 | self.lr = init_lr * (d_model ** -0.5) * min(step ** -0.5, step * warmup_step ** -1.5) 47 | 48 | for param_group in self.param_groups: 49 | param_group['lr'] = self.lr 50 | --------------------------------------------------------------------------------
/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | 5 | import torch 6 | 7 | from transformer.models import Transformer 8 | 9 | logger = logging.getLogger('transformer') 10 | 11 | 12 | def load_transformer_to_train(opt) -> Transformer: 13 | model = Transformer(embed_dim=opt.embed_dim, 14 | src_vocab_size=opt.src_vocab_size, 15 | trg_vocab_size=opt.trg_vocab_size, 16 | src_pad_idx=opt.src_pad_idx, 17 | trg_pad_idx=opt.trg_pad_idx, 18 | n_head=opt.n_head) 19 | model = model.to(opt.device) 20 | checkpoint_file_path = get_best_checkpoint(opt) 21 | if checkpoint_file_path is not None: 22 | logger.info(f'Checkpoint loaded - {checkpoint_file_path}') 23 | checkpoint = torch.load(checkpoint_file_path, map_location=opt.device) 24 | model.load_state_dict(checkpoint['weights']) 25 | return model 26 | 27 | 28 | def load_transformer(opt) -> Transformer: 29 | checkpoint_file_path = get_best_checkpoint(opt) 30 | checkpoint = torch.load(checkpoint_file_path, map_location=opt.device) 31 | 32 | assert checkpoint is not None 33 | assert checkpoint['opt'] is not
None 34 | assert checkpoint['weights'] is not None 35 | 36 | model_opt = checkpoint['opt'] 37 | model = Transformer(embed_dim=model_opt.embed_dim, 38 | src_vocab_size=model_opt.src_vocab_size, 39 | trg_vocab_size=model_opt.trg_vocab_size, 40 | src_pad_idx=model_opt.src_pad_idx, 41 | trg_pad_idx=model_opt.trg_pad_idx, 42 | n_head=model_opt.n_head) 43 | 44 | model.load_state_dict(checkpoint['weights']) 45 | print('model loaded:', checkpoint_file_path) 46 | return model.to(opt.device) 47 | 48 | 49 | def get_best_checkpoint(opt): 50 | regex = re.compile('checkpoint_(\d+)_(\d+\.\d+)\.chkpt') 51 | checkpoints = [] 52 | if os.path.exists(opt.checkpoint_path): 53 | for name in os.listdir(opt.checkpoint_path): 54 | if regex.match(name): 55 | checkpoints.append((name, float(regex.match(name).group(1)))) 56 | if not checkpoints: 57 | return None 58 | 59 | checkpoints = sorted(checkpoints, key=lambda x: -x[1]) 60 | return os.path.join(opt.checkpoint_path, checkpoints[0][0]) 61 | -------------------------------------------------------------------------------- /transformer/layers.py: -------------------------------------------------------------------------------- 1 | from transformer.sublayers import MultiHeadAttention, PositionWiseFeedForward 2 | from transformer.modules import NormLayer 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, embed_dim: int = 512, n_head: int = 8, d_ff: int = 2048, dropout: float = 0.1): 10 | super(EncoderLayer, self).__init__() 11 | self.norm1 = NormLayer(embed_dim) 12 | self.norm2 = NormLayer(embed_dim) 13 | 14 | self.self_attn = MultiHeadAttention(embed_dim, n_head, dropout=dropout) 15 | self.pos_ffn = PositionWiseFeedForward(embed_dim, d_ff=d_ff, dropout=dropout) 16 | 17 | self.dropout1 = nn.Dropout(dropout) 18 | self.dropout2 = nn.Dropout(dropout) 19 | 20 | def forward(self, x: torch.Tensor, mask: torch.Tensor): 21 | """ 22 | For sublayer in [MultiHeadAttetion, PostionWiseFeedForward]: 23 | 1. Normalize(x) 24 | 2. Do sublayer (like MultiHeadAttention) 25 | 3. 
Dropout(0.1) 26 | 27 | 논문에서는 다음과 같이 씌여져 있음 28 | We apply dropout to the output of each sub-layer, 29 | before it is added to the sub-layer input and normalized 30 | """ 31 | x2 = self.norm1(x) 32 | h = x + self.dropout1(self.self_attn(x2, x2, x2, mask)) 33 | h = x + self.dropout2(self.pos_ffn(self.norm2(h))) 34 | return h 35 | 36 | 37 | class DecoderLayer(nn.Module): 38 | def __init__(self, embed_dim: int, n_head: int = 8, d_ff: int = 2048, dropout: float = 0.1): 39 | super(DecoderLayer, self).__init__() 40 | self.norm1 = NormLayer(embed_dim) 41 | self.norm2 = NormLayer(embed_dim) 42 | 43 | self.self_attn1 = MultiHeadAttention(embed_dim, n_head, dropout=dropout) 44 | self.self_attn2 = MultiHeadAttention(embed_dim, n_head, dropout=dropout) 45 | self.pos_ffn = PositionWiseFeedForward(embed_dim, d_ff=d_ff, dropout=dropout) 46 | 47 | self.dropout1 = nn.Dropout(dropout) 48 | self.dropout2 = nn.Dropout(dropout) 49 | self.dropout3 = nn.Dropout(dropout) 50 | 51 | def forward(self, trg: torch.Tensor, trg_mask: torch.Tensor, enc_output: torch.Tensor, enc_mask: torch.Tensor): 52 | # First Multi Head Attention : target_input 53 | x2 = self.norm1(trg) 54 | dec_output = trg + self.dropout1(self.self_attn1(x2, x2, x2, trg_mask)) 55 | 56 | # Second Multi Head Attention : 1st multi-head attetion output + encoder output 57 | # TODO: 현재 dec_output에만 normalization이 들어갔는데, enc_output도 normalization도 실험 필요함 58 | dec_output2 = self.norm2(dec_output) 59 | dec_output = dec_output + self.dropout2(self.self_attn2(dec_output2, enc_output, enc_output, enc_mask)) 60 | 61 | # Postion-Wise Feed Forward 62 | dec_output = trg + self.dropout3(self.pos_ffn(dec_output)) 63 | return dec_output 64 | -------------------------------------------------------------------------------- /preprocess_chatbot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from argparse import Namespace 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import torchtext 8 | from torchtext.data import Field 9 | 10 | from tools import const 11 | from tools.tokenizer import Tokenizer, TokenizerKorea 12 | 13 | 14 | def init() -> Namespace: 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--max_seq_len', type=int, default=100) 18 | parser.add_argument('--min_word_freq', type=int, default=3, help='minimum word count') 19 | parser.add_argument('--share_vocab', action='store_true', default=True, help='Merge source vocab with target vocab') 20 | 21 | opt = parser.parse_args() 22 | return opt 23 | 24 | 25 | def create_torch_fields(opt: Namespace) -> Tuple[Field, Field]: 26 | tokenizer = TokenizerKorea() 27 | 28 | src = Field(tokenize=tokenizer.tokenizer, lower=True, 29 | pad_token=const.PAD, init_token=const.SOS, eos_token=const.EOS) 30 | trg = Field(tokenize=tokenizer.tokenizer, lower=True, 31 | pad_token=const.PAD, init_token=const.SOS, eos_token=const.EOS) 32 | 33 | return src, trg 34 | 35 | 36 | def merge_source_and_target(src, trg): 37 | for word in set(trg.vocab.stoi) - set(src.vocab.stoi): 38 | l = len(src.vocab.stoi) 39 | src.vocab.stoi[word] = l 40 | src.vocab.itos.append(word) 41 | src.vocab.freqs[word] = trg.vocab.freqs[word] 42 | trg.vocab.stoi = src.vocab.stoi 43 | trg.vocab.itos = src.vocab.itos 44 | trg.vocab.freqs = src.vocab.freqs 45 | 46 | print(f'Merged source vocabulary: {len(src.vocab)}') 47 | print(f'Merged target vocabulary: {len(trg.vocab)}') 48 | return src, trg 49 | 50 | 51 | def main(): 52 | opt = init() 53 | max_seq_len = 
opt.max_seq_len 54 | min_word_freq = opt.min_word_freq 55 | 56 | # Create Fields 57 | src, trg = create_torch_fields(opt) 58 | 59 | # Data - max_seq_len 값 이상 넘어가는 단어로 이루어진 문장을 제외 시킨다 60 | def filter_with_length(x): 61 | return len(x.src) <= max_seq_len and len(x.trg) <= max_seq_len 62 | 63 | train, val, test = torchtext.datasets.Multi30k.splits(exts=('.' + opt.lang_src, '.' + opt.lang_trg), 64 | fields=(src, trg), 65 | filter_pred=filter_with_length) 66 | src.build_vocab(train.src, min_freq=min_word_freq) # src.vocab.stoi, src.vocab.itos 생성 67 | print(f'Source vocabulary: {len(src.vocab)}') 68 | 69 | trg.build_vocab(train.trg, min_freq=min_word_freq) 70 | print(f'Target vocabulary: {len(trg.vocab)}') 71 | 72 | # Merge source vocabulary and target vocabulary 73 | if opt.share_vocab: 74 | src, trg = merge_source_and_target(src, trg) 75 | 76 | # Save data as a pickle 77 | data = { 78 | 'opt': opt, 79 | 'src': src, 80 | 'trg': trg, 81 | 'train': train.examples, 82 | 'val': val.examples, 83 | 'test': test.examples 84 | } 85 | 86 | pickle.dump(data, open(opt.save_data, 'wb')) 87 | print(f'The data saved at: {opt.save_data}') 88 | print('Preprocessing Completed Successfully') 89 | 90 | sentences = np.random.choice(train.examples, 5) 91 | print('\n[Train Data Examples]') 92 | for i, sentence in enumerate(sentences): 93 | print(f'[{i + 1}] Source:', sentence.src) 94 | print(f'[{i + 1}] Target:', sentence.trg) 95 | print() 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from argparse import Namespace 4 | from typing import Generator, List, Tuple 5 | 6 | import torch 7 | from nltk.corpus import wordnet 8 | from torchtext.vocab import Vocab 9 | from tqdm import tqdm 10 | 11 | from tools import const 12 | from transformer import load_transformer 13 | from torchtext.data import Dataset, Field, Example 14 | 15 | from transformer.debug import to_sentence 16 | from transformer.translator import Translator 17 | 18 | 19 | def init() -> Namespace: 20 | parser = argparse.ArgumentParser(description='translate') 21 | parser.add_argument('--output', default='pred.txt') 22 | parser.add_argument('--beam', default=5, type=int) 23 | parser.add_argument('--cuda', default='cuda') 24 | parser.add_argument('--data', default='.data/data.pkl') 25 | parser.add_argument('--max_seq_len', default=100, type=int) 26 | parser.add_argument('--checkpoint_path', default='checkpoints') 27 | 28 | opt = parser.parse_args() 29 | opt.device = torch.device('cuda' if opt.cuda else 'cpu') 30 | return opt 31 | 32 | 33 | def load_data(opt) -> Tuple[Dataset, Dataset, Vocab, Vocab]: 34 | data = pickle.load(open(opt.data, 'rb')) 35 | src: Field = data['src'] 36 | trg: Field = data['trg'] 37 | 38 | opt.src_pad_idx = src.vocab.stoi[const.PAD] 39 | opt.trg_pad_idx = trg.vocab.stoi[const.PAD] 40 | opt.trg_sos_idx = trg.vocab.stoi[const.SOS] 41 | opt.trg_eos_idx = trg.vocab.stoi[const.EOS] 42 | 43 | train_loader = Dataset(examples=data['train'], fields={'src': src, 'trg': trg}) 44 | test_loader = Dataset(examples=data['test'], fields={'src': src, 'trg': trg}) 45 | return train_loader, test_loader, src.vocab, trg.vocab 46 | 47 | 48 | def get_word_or_synonym(vocab: Vocab, word: str, unk_idx: int): 49 | if word in vocab.stoi: 50 | return vocab.stoi[word] 51 | 52 | syns = wordnet.synsets(word) 53 | for s in syns: 54 | for lemma 
in s.lemmas(): 55 | if lemma.name() in vocab.stoi: 56 | print('Synonym 사용:', lemma.name()) 57 | return vocab.stoi[lemma.name()] 58 | return unk_idx 59 | 60 | 61 | def iterate_test_data(data_loader: Dataset, 62 | device: torch.device) -> Generator[Tuple[Example, torch.LongTensor], 63 | Tuple[Example, torch.LongTensor], 64 | None]: 65 | src_vocab = data_loader.fields['src'].vocab 66 | unk_idx = src_vocab.stoi[const.UNK] 67 | 68 | for example in data_loader: # tqdm(data_loader, mininterval=1, desc='Evaluation', leave=False): 69 | src_seq = [get_word_or_synonym(src_vocab, word, unk_idx) for word in example.src] 70 | yield example, torch.LongTensor([src_seq]).to(device) 71 | 72 | 73 | def main(): 74 | opt = init() 75 | 76 | # Load test dataset 77 | train_loader, test_loader, src_vocab, trg_vocab = load_data(opt) 78 | 79 | # Load Translator 80 | translator = Translator(model=load_transformer(opt), 81 | beam_size=opt.beam, 82 | device=opt.device, 83 | max_seq_len=opt.max_seq_len, 84 | src_pad_idx=opt.src_pad_idx, 85 | trg_pad_idx=opt.trg_pad_idx, 86 | trg_sos_idx=opt.trg_sos_idx, 87 | trg_eos_idx=opt.trg_eos_idx, 88 | src_vocab=src_vocab, 89 | trg_vocab=trg_vocab) 90 | 91 | for i, (example, src_seq) in enumerate(iterate_test_data(test_loader, device=opt.device)): 92 | pred_seq = translator.translate(src_seq) 93 | pred_sentence = [] 94 | for idx in pred_seq: 95 | word = trg_vocab.itos[idx] 96 | if word not in {const.SOS, const.EOS}: 97 | pred_sentence.append(word) 98 | pred_sentence = ' '.join(pred_sentence) 99 | print(i, pred_sentence) 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from argparse import Namespace 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import torchtext 8 | from torchtext.data import Field 9 | 10 | from tools import const 11 | from tools.tokenizer import Tokenizer 12 | 13 | 14 | def init() -> Namespace: 15 | spacy_support_langs = ['de', 'el', 'en', 'es', 'fr', 'it', 'lt', 'nb', 'nl', 'pt'] 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--lang_src', default='de', choices=spacy_support_langs) 19 | parser.add_argument('--lang_trg', default='en', choices=spacy_support_langs) 20 | parser.add_argument('--data_src', type=str) 21 | parser.add_argument('--save_data', type=str, default='.data/data.pkl') 22 | 23 | parser.add_argument('--max_seq_len', type=int, default=100) 24 | parser.add_argument('--min_word_freq', type=int, default=3, help='minimum word count') 25 | parser.add_argument('--share_vocab', action='store_true', default=True, help='Merge source vocab with target vocab') 26 | 27 | opt = parser.parse_args() 28 | return opt 29 | 30 | 31 | def create_torch_fields(opt: Namespace) -> Tuple[Field, Field]: 32 | tokenizer_src = Tokenizer(opt.lang_src) 33 | tokenizer_trg = Tokenizer(opt.lang_trg) 34 | 35 | src = Field(tokenize=tokenizer_src.tokenizer, lower=True, 36 | pad_token=const.PAD, init_token=const.SOS, eos_token=const.EOS) 37 | trg = Field(tokenize=tokenizer_trg.tokenizer, lower=True, 38 | pad_token=const.PAD, init_token=const.SOS, eos_token=const.EOS) 39 | 40 | return src, trg 41 | 42 | 43 | def merge_source_and_target(src, trg): 44 | for word in set(trg.vocab.stoi) - set(src.vocab.stoi): 45 | l = len(src.vocab.stoi) 46 | src.vocab.stoi[word] = l 47 | src.vocab.itos.append(word) 48 | src.vocab.freqs[word] 
= trg.vocab.freqs[word] 49 | trg.vocab.stoi = src.vocab.stoi 50 | trg.vocab.itos = src.vocab.itos 51 | trg.vocab.freqs = src.vocab.freqs 52 | 53 | assert trg.vocab.stoi == src.vocab.stoi 54 | assert trg.vocab.itos == src.vocab.itos 55 | assert trg.vocab.freqs == src.vocab.freqs 56 | print(f'Merged source vocabulary: {len(src.vocab)}') 57 | print(f'Merged target vocabulary: {len(trg.vocab)}') 58 | 59 | return src, trg 60 | 61 | 62 | def main(): 63 | opt = init() 64 | max_seq_len = opt.max_seq_len 65 | min_word_freq = opt.min_word_freq 66 | 67 | # Create Fields 68 | src, trg = create_torch_fields(opt) 69 | 70 | # Data - max_seq_len 값 이상 넘어가는 단어로 이루어진 문장을 제외 시킨다 71 | def filter_with_length(x): 72 | return len(x.src) <= max_seq_len and len(x.trg) <= max_seq_len 73 | 74 | train, val, test = torchtext.datasets.Multi30k.splits(exts=('.' + opt.lang_src, '.' + opt.lang_trg), 75 | fields=(src, trg), 76 | filter_pred=filter_with_length) 77 | src.build_vocab(train.src, min_freq=min_word_freq) # src.vocab.stoi, src.vocab.itos 생성 78 | print(f'Source vocabulary: {len(src.vocab)}') 79 | 80 | trg.build_vocab(train.trg, min_freq=min_word_freq) 81 | print(f'Target vocabulary: {len(trg.vocab)}') 82 | 83 | # Merge source vocabulary and target vocabulary 84 | if opt.share_vocab: 85 | src, trg = merge_source_and_target(src, trg) 86 | 87 | # Save data as a pickle 88 | data = { 89 | 'opt': opt, 90 | 'src': src, 91 | 'trg': trg, 92 | 'train': train.examples, 93 | 'val': val.examples, 94 | 'test': test.examples 95 | } 96 | 97 | pickle.dump(data, open(opt.save_data, 'wb')) 98 | print(f'The data saved at: {opt.save_data}') 99 | print('Preprocessing Completed Successfully') 100 | 101 | sentences = np.random.choice(train.examples, 5) 102 | print('\n[Train Data Examples]') 103 | for i, sentence in enumerate(sentences): 104 | print(f'[{i + 1}] Source:', sentence.src) 105 | print(f'[{i + 1}] Target:', sentence.trg) 106 | print() 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /transformer/sublayers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Multi-Head Self Attention 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class MultiHeadAttention(nn.Module): 10 | """ 11 | 512인 embedding vector를 -> n_head에 따라서 나눈다. 
12 | 예를 들어, n_head=8일 경우 512 vector를 -> 8 * 64 vector 로 변형한다 13 | 14 | """ 15 | 16 | def __init__(self, embed_dim: int, n_head: int, dropout: float = 0.1): 17 | super(MultiHeadAttention, self).__init__() 18 | 19 | self.embed_dim = embed_dim 20 | self.n_head = n_head 21 | self.dk = embed_dim // n_head 22 | self.dv = embed_dim // n_head 23 | 24 | self.linear_q = nn.Linear(embed_dim, embed_dim, bias=False) 25 | self.linear_v = nn.Linear(embed_dim, embed_dim, bias=False) 26 | self.linear_k = nn.Linear(embed_dim, embed_dim, bias=False) 27 | self.linear_f = nn.Linear(embed_dim, embed_dim, bias=False) # Final linear layer 28 | 29 | self.attention = ScaleDotProductAttention(self.dk, dropout=dropout) 30 | 31 | self.dropout = nn.Dropout(dropout) 32 | 33 | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 34 | """ 35 | * 마지막 skip connection 은 layer 부분에서 구현함 36 | """ 37 | batch_size, n_head, dk, dv = q.size(0), self.n_head, self.dk, self.dv 38 | 39 | # Linear Transformation (256, 33, 512) 40 | # Multi Head : d_model(512) vector부분을 h개 (8개) 로 나눈다 41 | q = self.linear_q(q).view(batch_size, -1, n_head, dk) 42 | k = self.linear_k(k).view(batch_size, -1, n_head, dk) 43 | v = self.linear_v(v).view(batch_size, -1, n_head, dv) 44 | 45 | q = q.transpose(1, 2) 46 | k = k.transpose(1, 2) 47 | v = v.transpose(1, 2) 48 | 49 | scores = self.attention(q, k, v, mask) 50 | 51 | # multi head dimension 을 원래의 형태로 되돌린다 52 | # (batch, n_head, seq_len, d_v) (256, 8, 33, 64) --> (batch, seq_len, n_head, d_v) (256, 33, 8, 64) 53 | scores = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim) 54 | 55 | # Final linear Layer 56 | scores = self.linear_f(scores) 57 | 58 | return scores 59 | 60 | 61 | class ScaleDotProductAttention(nn.Module): 62 | """ 63 | Attention(Q, K, V) = softmax( (QK^T)/sqrt(d_k) ) 64 | """ 65 | 66 | def __init__(self, d_k: int, dropout: float): 67 | """ 68 | :param d_k: the number of heads 69 | """ 70 | super(ScaleDotProductAttention, self).__init__() 71 | self.sqrt_dk = d_k ** 0.5 # 8 = 64**0.5 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: 75 | """ 76 | :param q: Queries (256 batch, 8 d_k, 33 sequence, 64) 77 | :param k: Keys (256, 8, 33, 64) 78 | :param v: Values (256, 8, 33, 64) 79 | :param mask: mask (256, 1, 28) Source Mask 80 | :return: scaled dot attention: (256, 8, 33, 64) 81 | """ 82 | attn = (q @ k.transpose(-2, -1)) / self.sqrt_dk 83 | if mask is not None: # 논문에는 안나온 내용. 하지만 masking을 해주어야 한다 84 | mask = mask.unsqueeze(1) 85 | attn = attn.masked_fill(~mask, -1e9) 86 | 87 | attn = self.dropout(F.softmax(attn, dim=-1)) # softmax 이후 dropout도 논문에는 없으나 해야 한다 88 | output = attn @ v # (256, 8, 33, 64) 89 | 90 | return output 91 | 92 | 93 | class PositionWiseFeedForward(nn.Module): 94 | 95 | def __init__(self, embed_dim: int, d_ff: int = 2048, dropout: float = 0.1): 96 | super(PositionWiseFeedForward, self).__init__() 97 | 98 | self.w_1 = nn.Linear(embed_dim, d_ff) 99 | self.w_2 = nn.Linear(d_ff, embed_dim) 100 | self.dropout = nn.Dropout(dropout) 101 | 102 | def forward(self, x): 103 | """ 104 | 논문에서는 Position-wise Feed Forward를 할때 skip connection에 대한 이야기는 없습니다. 
105 | 다만 MultiHead 부분에서도 106 | """ 107 | residual = x 108 | x = F.relu(self.w_1(x)) 109 | x = self.dropout(x) 110 | x = self.w_2(x) 111 | return x + residual 112 | -------------------------------------------------------------------------------- /transformer/translator.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchtext.vocab import Vocab 7 | 8 | from transformer import Transformer 9 | from transformer.debug import to_sentence 10 | from transformer.mask import create_nopeak_mask, create_padding_mask 11 | 12 | 13 | class Translator(nn.Module): 14 | def __init__(self, model: Transformer, beam_size: int, device: torch.device, max_seq_len: int, 15 | src_pad_idx: int, trg_pad_idx: int, trg_sos_idx: int, trg_eos_idx: int, 16 | src_vocab: Vocab, trg_vocab: Vocab): 17 | super(Translator, self).__init__() 18 | 19 | self.model = model 20 | self.model.eval() 21 | 22 | self.beam_size = beam_size 23 | self.device = device 24 | self.max_seq_len = max_seq_len 25 | self.src_pad_idx = src_pad_idx 26 | self.trg_pad_idx = trg_pad_idx 27 | self.trg_sos_idx = trg_sos_idx 28 | self.trg_eos_idx = trg_eos_idx 29 | 30 | self.src_vocab = src_vocab 31 | self.trg_vocab = trg_vocab 32 | 33 | # init_trg_seq: [[""]] 로 시작하는 matrix 이며, output 의 초기값으로 사용됨 34 | # beam_output: beam search 를 하기 위해서 decoder에서 나온 output 값들을 저장한다 35 | init_trg_seq = torch.LongTensor([[self.trg_sos_idx]]).to(self.device) 36 | seq_arange = torch.arange(1, self.max_seq_len + 1, dtype=torch.long).to(self.device) 37 | beam_output = torch.full((self.beam_size, self.max_seq_len), self.trg_pad_idx, dtype=torch.long) 38 | beam_output[:, 0] = self.trg_sos_idx 39 | beam_output = beam_output.to(self.device) 40 | 41 | self.register_buffer('init_trg_seq', init_trg_seq) 42 | self.register_buffer('beam_output', beam_output) 43 | self.register_buffer('seq_arange', seq_arange) 44 | 45 | def _create_init_sequence(self, 46 | src_seq: torch.Tensor, 47 | src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 48 | """ 49 | :param src_seq: (1, seq_size) 50 | :param src_mask: (1, 1, seq_size) 51 | :return: 52 | - enc_output: (beam_size, seq_len, embed_dim) 53 | - beam_output: (beam_size, max_seq_len) 54 | - scores: (beam_size,) 55 | """ 56 | # Encoder 57 | # 먼저 source sentence tensor 를 encoder 에 집어넣고 encoder output 을 냅니다 58 | enc_output = self.model.encoder(src_seq, src_mask) # (1, seq_len, embed_size) 59 | 60 | # Decoder 61 | dec_output = self._decoder_softmax(self.init_trg_seq, enc_output, src_mask) 62 | k_probs, k_indices = dec_output[:, -1, :].topk(self.beam_size) 63 | scores = torch.log1p(k_probs).view(self.beam_size) 64 | 65 | # Generate beam sequences 66 | beam_output = self.beam_output.clone().detach() # (beam_size, max_seq_len) 67 | beam_output[:, 1] = k_indices[0] 68 | 69 | # Reshape encoder output 70 | enc_output = enc_output.repeat(self.beam_size, 1, 1) # (beam_size, seq_len, embed_dim) 71 | return enc_output, beam_output, scores 72 | 73 | def _decoder_softmax(self, 74 | trg_seq: torch.Tensor, 75 | enc_output: torch.Tensor, 76 | src_mask: torch.Tensor) -> torch.Tensor: 77 | trg_mask = create_nopeak_mask(trg_seq) # (1, 1, 1) 78 | dec_output = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask) # (1, 1, embed_size) 79 | dec_output = self.model.out_linear(dec_output) # (1, 1, 9473) 80 | dec_output = F.softmax(dec_output, dim=-1) # (1, 1, 9473) everything is zero except one 
element 81 | return dec_output 82 | 83 | def _calculate_scores(self, 84 | step: int, 85 | beam_output: torch.Tensor, 86 | dec_output: torch.Tensor, 87 | scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 88 | """ 89 | :param step: start from 2 to max_sequence or until reaching to the 90 | :param beam_output: (beam_size, max_seq_len) 91 | :param dec_output: (beam_size, 2~step, word_size) ex.(5, 2, 9473) -> next -> (5, 3, 9473) 92 | :param scores: (beam_size, beam_size) ex. (5, 5) 93 | :return: 94 | """ 95 | # k_probs: (beam_size, beam_size == topk) -> (5, 5) -> 가장 마지막으로 예측한 단어 top k 의 단어 확률 96 | # k_indices: (beam_size, beam_size) -> (5, 5) -> 가장 마지막으로 예측한 단어 top k 의 index 97 | k_probs, k_indices = dec_output[:, -1, :].topk(self.beam_size) 98 | 99 | # Calculate scores added by previous scores 100 | # 즉 단어 하나하나 생성하며서 문장을 만들어 나가는데.. 101 | # 누적 점수를 만들어 나가면서 나중에 가장 점수가 높은 문장을 찾아내겠다는 의미 102 | # scores: (beam_size, beam_size) ex. (5, 5) 103 | scores = torch.log1p(k_probs) + scores 104 | 105 | # (5, 5) 에서 만들어진 전체 단어중에서 best k 단어를 찾아낸다 106 | # beam_scores: (5*5,) 107 | # beam_indices: (5*5,) index는 0에서 25사이의 값이다. (즉 단어의 index가 아니다) 108 | beam_scores, beam_indices = scores.view(-1).topk(self.beam_size) 109 | _row_idx = beam_indices // self.beam_size 110 | _col_idx = beam_indices % self.beam_size 111 | best_indices = k_indices[_row_idx, _col_idx] # (beam_size,) k_indices안에 단어의 index가 들어있다 112 | 113 | # best_indices 와 row index와 동일하게 맞쳐준후, best indices 값을 추가한다 114 | beam_output[:, :step] = beam_output[_row_idx, :step] 115 | beam_output[:, step] = best_indices 116 | 117 | return beam_scores, beam_output 118 | 119 | def beam_search(self, src_seq): 120 | """ 121 | beam_output 설명 122 | 기본적으로 beam_output 은 다음과 같이 생겼다 123 | tensor([[ 2, 1615, 1, 1, 1], 124 | [ 2, 538, 1, 1, 1], 125 | [ 2, 2, 1, 1, 1], 126 | [ 2, 1, 1, 1, 1], 127 | [ 2, 0, 1, 1, 1]] 128 | 여기서 2="", 1="" 이며, 5개의 beam_size에서 forloop 을 돌면서, 그 다음 단어를 예측하며, 129 | "" 가 나올때까지 padding 부분을 단어 index 로 채워 나가면서 계속 진행한다 130 | 131 | :param src_seq: 132 | :return: 133 | """ 134 | # Create initial source padding mask 135 | src_mask = create_padding_mask(src_seq, pad_idx=self.src_pad_idx) 136 | src_mask = src_mask.to(self.device) 137 | enc_output, beam_output, scores = self._create_init_sequence(src_seq, src_mask) 138 | ans_row_idx = 0 139 | for step in range(2, self.max_seq_len): 140 | dec_output = self._decoder_softmax(beam_output[:, :step], enc_output, src_mask) 141 | scores, beam_output = self._calculate_scores(step, beam_output, dec_output, scores) 142 | 143 | # Find complete setences and end this loop 144 | eos_loc = beam_output == self.trg_eos_idx # (beam_size, max_seq_size) ex. 
(5, 100) 145 | 146 | # (beam_size, max_seq_size) 에서 대부분 max_seq_len 값을 갖고, trg_eos 만 실제 index값을 갖는다 147 | eos_indices, _ = self.seq_arange.masked_fill(~eos_loc, self.max_seq_len).min(1) 148 | 149 | n_complete_sentences = (eos_loc.sum(1) > 0).sum().item() 150 | 151 | # DEBUG 152 | # print(to_sentence(src_seq[0], self.src_vocab)) 153 | # for i in range(5): 154 | # print(to_sentence(beam_output[i], self.src_vocab)[:150]) 155 | 156 | if n_complete_sentences == self.beam_size: 157 | ans_row_idx = scores.max(0)[1].item() 158 | break 159 | 160 | return beam_output[ans_row_idx][:eos_indices[ans_row_idx]].tolist() 161 | 162 | def translate(self, src_seq: torch.LongTensor): 163 | assert src_seq.size(0) == 1 # Batch Size should be 1 164 | 165 | with torch.no_grad(): 166 | pred_seq = self.beam_search(src_seq) 167 | return pred_seq 168 | -------------------------------------------------------------------------------- /transformer/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.container import ModuleList 4 | 5 | from transformer.embed import PositionalEncoding 6 | from transformer.layers import EncoderLayer, DecoderLayer 7 | from transformer.mask import create_mask 8 | from transformer.modules import NormLayer 9 | 10 | 11 | class Transformer(nn.Module): 12 | 13 | def __init__(self, embed_dim: int, src_vocab_size: int, trg_vocab_size: int, 14 | src_pad_idx: int, trg_pad_idx: int, 15 | n_layers: int = 6, n_head: int = 8, d_ff: int = 2048, dropout: float = 0.1, 16 | share_embed_weights: bool = True, 17 | share_embed_weights_to_output_linear: bool = True): 18 | """ 19 | # Vocab 20 | :param src_vocab_size: the size of the source vocabulary 21 | :param trg_vocab_size: the size of the target vocabulary 22 | :param src_pad_idx: source padding index 23 | :param trg_pad_idx: target padding index 24 | 25 | # Embedding 26 | :param embed_dim: embedding dimension. 
512 is used in the paper 27 | :param share_embed_weights: share target embedding weights with source embedding 28 | :param share_embed_weights_to_output_linear: share target embedding weights with output linear weights 29 | 30 | # Transformer 31 | :param n_layers: the number of sub-layers 32 | :param n_head: the number of heads in Multi Head Attention 33 | :param d_ff: inner dimension of position-wise feed-forward 34 | """ 54 | super(Transformer, self).__init__() 55 | 56 | self.src_pad_idx = src_pad_idx 57 | self.trg_pad_idx = trg_pad_idx 58 | 59 | self.encoder = Encoder(src_vocab_size, embed_dim=embed_dim, n_head=n_head, d_ff=d_ff, 60 | pad_idx=src_pad_idx, n_layers=n_layers, dropout=dropout) 61 | self.decoder = Decoder(trg_vocab_size, embed_dim=embed_dim, n_head=n_head, d_ff=d_ff, 62 | pad_idx=trg_pad_idx, n_layers=n_layers, dropout=dropout) 63 | 64 | self.out_linear = nn.Linear(embed_dim, trg_vocab_size, bias=False) 65 | 66 | self.logit_scale = 1 67 | if share_embed_weights_to_output_linear: 68 | # The following is quoted from the paper: 69 | # 3.4 Embeddings and Softmax 70 | # We also use the usual learned linear transformation and softmax function to convert the decoder output 71 | # to predicted next-token probabilities. In our model, we share the same weight matrix between 72 | # the two embedding layers and the pre-softmax linear transformation, similar to [30]. 73 | # In the embedding layers, we multiply those weights by sqrt(d_model). 74 | self.out_linear.weight = self.decoder.embed.weight 75 | self.logit_scale = (embed_dim ** -0.5) 76 | 77 | if share_embed_weights: 78 | self.encoder.embed.weight = self.decoder.embed.weight 79 | 80 | def forward(self, src: torch.Tensor, trg: torch.Tensor): 81 | """ 82 | :param src: (batch_size, maximum_sequence_length) 83 | :param trg: (batch_size, maximum_sequence_length) 84 | :return (batch, seq_len, trg_vocab_size) ex.(256, 33, 9473) 85 | """ 86 | src_mask, trg_mask = create_mask(src, trg, src_pad_idx=self.src_pad_idx, trg_pad_idx=self.trg_pad_idx) 87 | 88 | enc_output = self.encoder(src, src_mask) # (batch, seq_len, embed_dim) like (256, 33, 512) 89 | dec_output = self.decoder(trg, trg_mask, enc_output, src_mask) # (batch, seq_len, embed_dim) 90 | output = self.out_linear(dec_output) * self.logit_scale # (batch, seq_len, trg_vocab_size) like (256, 33, 9473) 91 | return output 92 | 93 | 94 | class Encoder(nn.Module): 95 | def __init__(self, vocab_size: int, embed_dim: int, n_head: int, d_ff: int, 96 | pad_idx: int, n_layers: int, dropout: float = 0.1): 97 | """ 98 | Embedding Parameters 99 | :param vocab_size: the size of the source vocabulary 100 | :param embed_dim: embedding dimension. 512 is used in the paper 101 | :param n_head: the number of multi head. (split the embed_dim to 8..
such that 8 * 64 = 512) 102 | :param d_ff: inner dimension of position-wise feed-forward 103 | :param pad_idx: padding index 104 | :param n_layers: the number of sub-layers 105 | 106 | Flow 107 | 1. embedding layer 108 | 2. positional encoding 109 | 3. residual dropout(0.1) 110 | 4. iterate sub-layers (6 layers are used in paper) 111 | """ 112 | super(Encoder, self).__init__() 113 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx) 114 | self.position_enc = PositionalEncoding(embed_dim) 115 | self.dropout = nn.Dropout(dropout) # Residual Dropout(0.1) in paper 116 | 117 | self.layer_stack: ModuleList[EncoderLayer] = nn.ModuleList( 118 | [EncoderLayer(embed_dim, n_head, d_ff=d_ff, dropout=dropout) for _ in range(n_layers)]) 119 | 120 | self.layer_norm = NormLayer(embed_dim) 121 | 122 | def forward(self, src: torch.Tensor, mask: torch.Tensor): 123 | """ 124 | Sublayers 에 들어가기 전에 논문에서는 다음 3가지를 해야 함 125 | Word Tensor -> Embedding -> Positional Embedding -> Dropout(0.1) 126 | 127 | 아래는 논문 내용 128 | we apply dropout to the sums of the embeddings and the 129 | positional encodings in both the encoder and decoder stacks 130 | 131 | :return encoder output : (batch, seq_len, embed_dim) like (256, 33, 512) 132 | """ 133 | x = self.embed(src) # (256 batch, 33 seqence, 512 embedding) 134 | x = self.position_enc(x) # (256, 33, 512) 135 | x = self.dropout(x) # Layer Stack 사용전에 Dropout 을 해야 함 (Decoder 에도 해야 됨) 136 | 137 | for enc_layer in self.layer_stack: 138 | x = enc_layer(x, mask) 139 | 140 | enc_output = self.layer_norm(x) # (256, 33, 512) 141 | return enc_output 142 | 143 | 144 | class Decoder(nn.Module): 145 | 146 | def __init__(self, vocab_size: int, embed_dim: int, n_head: int, d_ff: int, 147 | pad_idx: int, n_layers: int, dropout: float = 0.1): 148 | """ 149 | :param vocab_size: the size of the target vocabulary 150 | :param embed_dim: embedding dimension. 512 is used in the paper 151 | :param n_head: the number of multi head. (split the embed_dim to 8.. 
such that 8 * 64 = 512) 152 | :param d_ff: inner dimension of position-wise feed-forward 153 | :param pad_idx: target padding index 154 | :param n_layers: the number of sub-layers 155 | """ 156 | super(Decoder, self).__init__() 157 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx) 158 | self.position_enc = PositionalEncoding(embed_dim) 159 | self.dropout = nn.Dropout(dropout) # Residual Dropout(0.1) in paper 160 | 161 | self.layer_stack: ModuleList[DecoderLayer] = nn.ModuleList( 162 | [DecoderLayer(embed_dim, n_head, d_ff=d_ff, dropout=dropout) for _ in range(n_layers)]) 163 | 164 | self.layer_norm = NormLayer(embed_dim) 165 | 166 | def forward(self, trg: torch.Tensor, trg_mask: torch.Tensor, 167 | enc_output: torch.Tensor, enc_mask: torch.Tensor): 168 | dec_output = self.embed(trg) # (256 batch, 33 seqence, 512 embedding) 169 | dec_output = self.position_enc(dec_output) # (256, 33, 512) 170 | dec_output = self.dropout(dec_output) # Layer Stack 사용전에 Dropout 을 해야 함 171 | 172 | for dec_layer in self.layer_stack: 173 | dec_output = dec_layer(dec_output, trg_mask, enc_output, enc_mask) 174 | 175 | dec_output = self.layer_norm(dec_output) 176 | 177 | return dec_output 178 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from argparse import Namespace 5 | from datetime import datetime 6 | from typing import Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from tqdm import tqdm 11 | 12 | from tools.data_loader import load_preprocessed_data 13 | from transformer import load_transformer_to_train 14 | from transformer.debug import to_sentence 15 | from transformer.models import Transformer 16 | from transformer.optimizer import ScheduledAdam 17 | 18 | # Logging Configuration 19 | logFormatter = logging.Formatter('%(message)s') 20 | logger = logging.getLogger('transformer') 21 | logger.setLevel(logging.DEBUG) 22 | 23 | fileHandler = logging.FileHandler('.train.log', 'w+') 24 | fileHandler.setLevel(logging.DEBUG) 25 | fileHandler.setFormatter(logFormatter) 26 | logger.addHandler(fileHandler) 27 | 28 | streamHandler = logging.StreamHandler() 29 | streamHandler.setLevel(logging.DEBUG) 30 | streamHandler.setFormatter(logFormatter) 31 | logger.addHandler(streamHandler) 32 | 33 | 34 | def init() -> Namespace: 35 | parser = argparse.ArgumentParser() 36 | # Data 37 | parser.add_argument('--data_pkl', default='.data/data.pkl', type=str) 38 | 39 | # System 40 | parser.add_argument('--cuda', action='store_true', default=True) 41 | parser.add_argument('--checkpoint_path', default='checkpoints', type=str) 42 | 43 | # Hyper Parameters 44 | parser.add_argument('--epoch', default=500, type=int) 45 | parser.add_argument('--batch_size', default=256, type=int) 46 | parser.add_argument('--embed_dim', default=512, type=int) 47 | parser.add_argument('--n_head', default=8, type=int, help='the number of multi heads') 48 | parser.add_argument('--warmup_steps', default=10000, type=int, help='the number of warmup steps') 49 | 50 | # Parse 51 | parser.set_defaults(share_embed_weights=True) 52 | opt = parser.parse_args() 53 | 54 | assert opt.embed_dim % opt.n_head == 0, 'the number of heads should be the multiple of embed_dim' 55 | 56 | opt.device = torch.device('cuda' if opt.cuda else 'cpu') 57 | logger.debug(f'device: {opt.device}') 58 | return opt 59 | 60 | 61 | def train(opt: Namespace, model: 
Transformer, optimizer: ScheduledAdam): 62 | if not os.path.exists(opt.checkpoint_path): 63 | os.makedirs(opt.checkpoint_path) 64 | 65 | train_data, val_data, src_vocab, trg_vocab = load_preprocessed_data(opt) 66 | min_loss = float('inf') 67 | 68 | for epoch in range(opt.epoch): 69 | # Training and Evaluation 70 | _t = train_per_epoch(opt, model, optimizer, train_data, src_vocab, trg_vocab) 71 | _v = evaluate_epoch(opt, model, val_data, src_vocab, trg_vocab) 72 | 73 | # Save a checkpoint only when the validation loss improves 74 | is_checkpointed = _v['loss_per_word'] < min_loss 75 | if is_checkpointed: 76 | min_loss = _v['loss_per_word'] 77 | checkpoint = {'epoch': epoch, 'opt': opt, 78 | 'weights': model.state_dict(), 79 | 'loss': min_loss, 80 | '_t': _t, 81 | '_v': _v} 82 | model_name = os.path.join(opt.checkpoint_path, f'checkpoint_{epoch:04}_{min_loss:.4f}.chkpt') 83 | torch.save(checkpoint, model_name) 84 | 85 | # Print performance 86 | _show_performance(epoch=epoch, step=optimizer.n_step, lr=optimizer.lr, t=_t, v=_v, 87 | checkpoint=is_checkpointed) 88 | 89 | 90 | def train_per_epoch(opt: Namespace, 91 | model: Transformer, 92 | optimizer: ScheduledAdam, 93 | train_data, 94 | src_vocab, 95 | trg_vocab) -> dict: 96 | model.train() 97 | start_time = datetime.now() 98 | total_loss = total_word = total_corrected_word = 0 99 | 100 | for i, batch in tqdm(enumerate(train_data), total=len(train_data), leave=False): 101 | src_input, trg_input, y_true = _prepare_batch_data(batch, opt.device) 102 | 103 | # Forward 104 | optimizer.zero_grad() 105 | y_pred = model(src_input, trg_input) 106 | 107 | # DEBUG (commented out: ipdb.set_trace() would stop training at the first batch) 108 | # pred_sentence = to_sentence(y_pred[0], trg_vocab) 109 | # true_sentence = to_sentence(batch.trg[:, 0], trg_vocab) 110 | # print(pred_sentence) 111 | # print(true_sentence) 112 | # import ipdb 113 | # ipdb.set_trace() 114 | 115 | # Backward and update parameters 116 | loss = calculate_loss(y_pred, y_true, opt.trg_pad_idx, trg_vocab) 117 | n_word, n_corrected = calculate_performance(y_pred, y_true, opt.trg_pad_idx) 118 | loss.backward() 119 | optimizer.step() 120 | 121 | # Training Logs 122 | total_loss += loss.item() 123 | total_word += n_word 124 | total_corrected_word += n_corrected 125 | 126 | loss_per_word = total_loss / total_word 127 | accuracy = total_corrected_word / total_word 128 | 129 | return {'total_seconds': (datetime.now() - start_time).total_seconds(), 130 | 'total_loss': total_loss, 131 | 'total_word': total_word, 132 | 'total_corrected_word': total_corrected_word, 133 | 'loss_per_word': loss_per_word, 134 | 'accuracy': accuracy} 135 | 136 | 137 | def evaluate_epoch(opt: Namespace, model: Transformer, val_data, src_vocab, trg_vocab): 138 | model.eval() 139 | start_time = datetime.now() 140 | total_loss = total_word = total_corrected_word = 0 141 | 142 | with torch.no_grad(): 143 | for i, batch in tqdm(enumerate(val_data), total=len(val_data), leave=False): 144 | # Prepare validation data 145 | src_input, trg_input, y_true = _prepare_batch_data(batch, opt.device) 146 | 147 | # Forward 148 | y_pred = model(src_input, trg_input) 149 | loss = calculate_loss(y_pred, y_true, opt.trg_pad_idx, trg_vocab) 150 | n_word, n_corrected = calculate_performance(y_pred, y_true, opt.trg_pad_idx) 151 | 152 | # Validation Logs 153 | total_loss += loss.item() 154 | total_word += n_word 155 | total_corrected_word += n_corrected 156 | 157 | loss_per_word = total_loss / total_word 158 | accuracy = total_corrected_word / total_word 159 | 160 | return {'total_seconds': (datetime.now() - start_time).total_seconds(), 161 | 'total_loss': total_loss, 162 | 'total_word': total_word, 163 | 'total_corrected_word': total_corrected_word, 164 | 'loss_per_word': loss_per_word, 165 | 'accuracy': accuracy} 166 |
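The `_prepare_batch_data` helper below performs the usual teacher-forcing shift: the decoder input keeps `<sos>` and drops the final token, while the labels drop `<sos>` and are flattened. A standalone sketch of that shift on a toy tensor (illustrative only; the indices are made up):

```python
import torch

# (batch=1, seq_len=5): <sos>=2, three word indices, <eos>=3
trg = torch.tensor([[2, 10, 11, 12, 3]])
trg_input = trg[:, :-1]                    # decoder input : [[ 2, 10, 11, 12]]
y_true = trg[:, 1:].contiguous().view(-1)  # labels        : [10, 11, 12,  3]
print(trg_input, y_true)
```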
167 | 168 | def _prepare_batch_data(batch, device): 169 | """ 170 | Prepare data 171 | - src_input : <sos>, foreign_word_1, foreign_word_2, ..., foreign_word_n, <eos>, pad_1, ..., pad_n 172 | - trg_input : (256, 33) -> <sos>, english_word_1, english_word_2, ..., english_word_n, <eos>, pad_1, ..., pad_n-1 173 | - y_true : (256 * 33) -> english_word_1, english_word_2, ..., english_word_n, <eos>, pad_1, ..., pad_n 174 | """ 175 | src_input = batch.src.transpose(0, 1).to(device) # (seq_length, batch) -> (batch, seq_length) 176 | trg_input = batch.trg.transpose(0, 1).to(device) # (seq_length, batch) -> (batch, seq_length) 177 | trg_input, y_true = trg_input[:, :-1], trg_input[:, 1:].contiguous().view(-1) 178 | return src_input, trg_input, y_true 179 | 180 | 181 | def calculate_performance(y_pred: torch.Tensor, y_true: torch.Tensor, trg_pad_idx: int) -> Tuple[int, int]: 182 | y_pred = y_pred.view(-1, y_pred.size(-1)) 183 | y_argmax = y_pred.argmax(dim=1) 184 | y_true = y_true.contiguous().view(-1) 185 | 186 | non_pad_mask = y_true != trg_pad_idx 187 | n_corrected = (y_argmax == y_true).masked_select(non_pad_mask).sum().item() 188 | n_word = non_pad_mask.sum().item() 189 | 190 | return n_word, n_corrected 191 | 192 | 193 | def calculate_loss(y_pred, y_true, trg_pad_idx, trg_vocab): 194 | """ 195 | y_pred comes in as vectors of size trg_vocab_size, while y_true comes in as word indices. 196 | They can be passed to F.cross_entropy as they are; padding positions are excluded via ignore_index. 197 | 198 | :param y_pred: (batch * seq_len, trg_vocab_size) ex. (256*33, 9473) 199 | :param y_true: (batch * seq_len) 200 | :param trg_pad_idx: 201 | :return: 202 | """ 203 | # DEBUG 204 | # true_sentence = to_sentence(y_true.reshape(64, -1)[0], trg_vocab) 205 | # pred_sentence = to_sentence(y_pred[0], trg_vocab) 206 | # print(true_sentence) 207 | # print(pred_sentence) 208 | 209 | y_pred = y_pred.view(-1, y_pred.size(-1)) 210 | y_true = y_true.contiguous().view(-1) 211 | 212 | return F.cross_entropy(y_pred, y_true, ignore_index=trg_pad_idx, reduction='sum') 213 | 214 | 215 | def _show_performance(epoch, step, lr, t, v, checkpoint): 216 | mins = int(t['total_seconds'] / 60) 217 | secs = int(t['total_seconds'] % 60) 218 | 219 | t_loss = t['total_loss'] 220 | t_accuracy = t['accuracy'] 221 | t_loss_per_word = t['loss_per_word'] 222 | 223 | v_loss = v['total_loss'] 224 | v_accuracy = v['accuracy'] 225 | v_loss_per_word = v['loss_per_word'] 226 | 227 | msg = f'[{epoch + 1:02}] {mins:02}:{secs:02} | loss:{t_loss:10.2f}/{v_loss:10.2f} | ' \ 228 | f'acc:{t_accuracy:7.4f}/{v_accuracy:7.4f} | ' \ 229 | f'loss_per_word:{t_loss_per_word:5.2f}/{v_loss_per_word:5.2f} | step:{step:5} | lr:{lr:6.4f}' \ 230 | f'{" | checkpoint" if checkpoint else ""}' 231 | logger.info(msg) 232 | 233 | 234 | def main(): 235 | opt = init() 236 | train_data, val_data, src_vocab, trg_vocab = load_preprocessed_data(opt) 237 | 238 | transformer = load_transformer_to_train(opt) 239 | optimizer = ScheduledAdam(transformer.parameters(), opt.embed_dim, warmup_steps=opt.warmup_steps) 240 | 241 | train(opt, transformer, optimizer) 242 | 243 | 244 | if __name__ == '__main__': 245 | main() 246 | --------------------------------------------------------------------------------
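The warm-up behaviour implemented by `ScheduledAdam` in `transformer/optimizer.py` can be checked in isolation. The following sketch (not part of the repository) reproduces the schedule `lr = init_lr * embed_dim**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)` with the optimizer's defaults (`init_lr=2.0`, `embed_dim=512`, `warmup_steps=4000`; note that `train.py` passes `warmup_steps=10000` instead) and prints a few sample values, showing the linear warm-up followed by inverse-square-root decay:

```python
def noam_lr(step: int, embed_dim: int = 512, warmup_steps: int = 4000, init_lr: float = 2.0) -> float:
    # Same formula as ScheduledAdam._update_learning_rate.
    return init_lr * (embed_dim ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)


if __name__ == '__main__':
    for step in (1, 100, 1000, 4000, 10000, 100000):
        print(f'step {step:>6}: lr = {noam_lr(step):.6f}')
```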