├── Constants.py
├── README.md
├── .gitignore
├── Optim.py
├── dataset.py
├── Model.py
├── preprocess.py
└── train.py

/Constants.py:
--------------------------------------------------------------------------------

PAD = 0
UNK = 1
BOS = 2
EOS = 3

PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# cnn-seq2seq

# PyTorch Implementation of [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122)


## 1. Data

### **_[the open parallel corpus](http://opus.lingfil.uu.se/)_**

1.1 [EUROPARL v7 - European Parliament Proceedings](http://opus.lingfil.uu.se/Europarl.php) ([Europarlv7.tar.gz](http://opus.lingfil.uu.se/download.php?f=Europarl/Europarlv7.tar.gz) - 8.4 GB)


## 2. Preprocess

### test files

\# -files: a directory containing train.src, train.tgt, valid.src, valid.tgt, test.src and test.tgt

\# -save_data: the output prefix; the prepared data is saved to <save_data>.train.pt

python preprocess.py -files /home/zeng/conversation/OpenNMT-py/data/test/ -save_data /home/zeng/data/test/test
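
## 3. Train

The flags below are defined in train.py; the example paths simply continue the preprocessing step above.

\# -data: the .train.pt file written by preprocess.py

\# -save_model: output prefix; a checkpoint is saved as <save_model>_acc_ACC_ppl_PPL_eEPOCH.pt after every epoch

python train.py -data /home/zeng/data/test/test.train.pt -save_model model -gpus 0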
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

*.pyc
/data
*.pt
--------------------------------------------------------------------------------
/Optim.py:
--------------------------------------------------------------------------------
import torch.optim as optim
from torch.nn.utils import clip_grad_norm


class Optim(object):

    def set_parameters(self, params, momentum=0.9):
        self.params = list(params)  # careful: params may be a generator
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr)
            # self.optimizer = optim.SGD(self.params, lr=self.lr, momentum=momentum)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None):
        self.last_ppl = None
        self.lr = lr
        self.max_grad_norm = max_grad_norm
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False

    def step(self):
        # Clip the gradient norm (as promised by the -max_grad_norm help text
        # in train.py), then update the parameters.
        if self.max_grad_norm:
            clip_grad_norm(self.params, self.max_grad_norm)
        self.optimizer.step()

    # decay learning rate if validation perplexity does not improve,
    # or once we hit the start_decay_at epoch
    def updateLearningRate(self, ppl, epoch):
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)

        self.last_ppl = ppl
        self.optimizer.param_groups[0]['lr'] = self.lr
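
# A minimal usage sketch (hypothetical toy setting, mirroring how train.py
# drives this wrapper): construct it, hand it the parameters, step once per
# batch, and decay once per epoch based on validation perplexity.
#
#   optim = Optim('adam', lr=0.001, max_grad_norm=5,
#                 lr_decay=0.5, start_decay_at=8)
#   optim.set_parameters(model.parameters())
#   ...
#   loss.backward()
#   optim.step()                               # clip, then update
#   ...
#   optim.updateLearningRate(valid_ppl, epoch)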
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from __future__ import division

import math

import torch
from torch.autograd import Variable


class Dataset(object):

    def __init__(self, xs, ys, batch_size, cuda, volatile=False):
        self.xs = xs
        self.ys = ys
        assert len(self.xs) == len(self.ys)
        self.batch_size = batch_size
        self.numBatches = int(math.ceil(len(self.xs) / batch_size))
        self.volatile = volatile
        self.cuda = cuda

    def _batchify(self, data, align_right=False, include_lengths=False, PADDING_TOKEN=0):
        lengths = [x.size(0) for x in data]
        max_length = max(lengths)
        out = data[0].new(len(data), max_length).fill_(PADDING_TOKEN)
        for i in range(len(data)):
            data_length = data[i].size(0)
            offset = max_length - data_length if align_right else 0
            out[i].narrow(0, offset, data_length).copy_(data[i])
        if include_lengths:
            return out, lengths
        else:
            return out

    def __getitem__(self, index):
        assert index < self.numBatches, "%d > %d" % (index, self.numBatches)
        xs, lengths = self._batchify(
            self.xs[index*self.batch_size:(index+1)*self.batch_size],
            align_right=False, include_lengths=True)

        ys = self._batchify(
            self.ys[index*self.batch_size:(index+1)*self.batch_size])

        # within-batch sorting by decreasing length, for variable-length rnns
        indices = range(len(xs))
        batch = zip(indices, xs, ys)
        batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1]))
        indices, xs, ys = zip(*batch)

        def wrap(b):
            if b is None:
                return b
            b = torch.stack(b, 0).t().contiguous()
            if self.cuda:
                b = b.cuda()
            b = Variable(b, volatile=self.volatile)
            return b

        return (wrap(xs), lengths), wrap(ys)

    def __len__(self):
        return self.numBatches

    def shuffle(self):
        data = list(zip(self.xs, self.ys))
        self.xs, self.ys = zip(*[data[i] for i in torch.randperm(len(data))])
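
# _batchify at a glance (a small worked example, not part of the original
# code): three source tensors of lengths 3, 1 and 2 are left-aligned into one
# 3 x 3 LongTensor, padded with 0 (= Constants.PAD):
#
#   [5, 8, 2]          [[5, 8, 2],
#   [7]         --->    [7, 0, 0],
#   [4, 9]              [4, 9, 0]]
#
# wrap() then stacks and transposes each batch to (seq_len, batch), the layout
# Model.py expects, so __getitem__ returns ((src, lengths), tgt).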
--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import Constants


class Encoder(nn.Module):
    """
    Args:
        input: ((seq_len, batch) LongTensor, lengths), as produced by dataset.Dataset
    Returns:
        attn: batch * seq_len, hidden_size  -- gated encoder states
        out: batch, seq_len, 2 * hidden_size  -- encoder output with residual connection
    """
    def __init__(self, opt, vocab_size):
        super(Encoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = opt.embedding_size
        self.hidden_size = opt.hidden_size

        self.in_channels = 1
        self.out_channels = opt.hidden_size * 2
        self.kernel_size = opt.kernel_size
        self.kernel = (opt.kernel_size, opt.hidden_size * 2)
        self.stride = 1
        # (kernel_size - 1) // 2 rows of zero padding keep the sequence length unchanged
        self.padding = ((opt.kernel_size - 1) // 2, 0)
        self.layers = opt.enc_layers

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.affine = nn.Linear(self.embedding_size, 2 * self.hidden_size)
        self.softmax = nn.Softmax()

        self.conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel,
                              self.stride, self.padding)

        self.mapping = nn.Linear(self.hidden_size, 2 * self.hidden_size)

    def forward(self, input):
        inputs = self.embedding(input[0])                       # seq_len, batch, embedding_size
        _inputs = inputs.view(-1, inputs.size(2))
        _outputs = self.affine(_inputs)
        _outputs = _outputs.view(inputs.size(0), inputs.size(1), -1).transpose(0, 1)  # batch, seq_len, 2*hidden
        outputs = _outputs
        for i in range(self.layers):
            outputs = outputs.unsqueeze(1)                      # batch, 1, seq_len, 2*hidden
            outputs = self.conv(outputs)                        # batch, out_channels, seq_len, 1
            outputs = F.relu(outputs)
            outputs = outputs.squeeze(3).transpose(1, 2)        # batch, seq_len, 2*hidden
            A, B = outputs.split(self.hidden_size, 2)           # A, B: batch, seq_len, hidden
            A2 = A.contiguous().view(-1, A.size(2))             # A2: batch * seq_len, hidden
            B2 = B.contiguous().view(-1, B.size(2))             # B2: batch * seq_len, hidden
            # gating as in a gated linear unit; note the paper gates with
            # sigmoid(B), while this implementation uses a softmax over the
            # hidden dimension
            attn = torch.mul(A2, self.softmax(B2))              # attn: batch * seq_len, hidden
            attn2 = self.mapping(attn)                          # attn2: batch * seq_len, 2 * hidden
            outputs = attn2.view(A.size(0), A.size(1), -1)      # outputs: batch, seq_len, 2 * hidden

        # residual connection from the affine-projected embeddings
        out = attn2.view(A.size(0), A.size(1), -1) + _outputs   # batch, seq_len, 2 * hidden_size

        return attn, out

    def load_pretrained_vectors(self, opt):
        if opt.pre_word_vecs_enc is not None:
            pretrained = torch.load(opt.pre_word_vecs_enc)
            self.embedding.weight.data.copy_(pretrained)


class Decoder(nn.Module):
    """
    Args:
        source: ((seq_len_src, batch) LongTensor, lengths)
        target: (seq_len_tgt, batch) LongTensor
        enc_attn: batch * seq_len_src, hidden_size  -- gated encoder states
        source_seq_out: batch, seq_len_src, 2 * hidden_size
    Returns:
        outputs: batch, seq_len_tgt - 1, 2 * hidden_size
    """

    def __init__(self, opt, vocab_size):
        super(Decoder, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_size = opt.embedding_size
        self.hidden_size = opt.hidden_size

        self.in_channels = 1
        self.out_channels = opt.hidden_size * 2
        self.kernel_size = opt.kernel_size
        self.kernel = (opt.kernel_size, opt.hidden_size * 2)
        self.stride = 1
        # pad with kernel_size - 1 rows so the convolution cannot see future tokens
        self.padding = (opt.kernel_size - 1, 0)
        self.layers = 1  # opt.dec_layers

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.affine = nn.Linear(self.embedding_size, 2 * self.hidden_size)
        self.softmax = nn.Softmax()

        self.conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel,
                              self.stride, self.padding)

        self.mapping = nn.Linear(self.hidden_size, 2 * self.hidden_size)

    def forward(self, source, target, enc_attn, source_seq_out):
        inputs = self.embedding(target)
        _inputs = inputs.view(-1, inputs.size(2))
        outputs = self.affine(_inputs)
        outputs = outputs.view(inputs.size(0), inputs.size(1), -1).transpose(0, 1)  # batch, seq_len, 2*hidden
        for i in range(self.layers):
            outputs = outputs.unsqueeze(1)  # batch, 1, seq_len, 2*hidden
            outputs = self.conv(outputs)    # batch, out_channels, seq_len + kernel_size - 1, 1
            outputs = outputs.narrow(2, 0, outputs.size(2) - self.kernel_size)  # drop the trailing elements

            # Residual connection across stacked layers, since the conv's
            # padding adds kernel_size - 1 elements around the original input
            # (initialized on the first layer to avoid a NameError; with
            # self.layers = 1 it is effectively unused).
            if i == 0:
                conv_out = outputs
            else:
                conv_out = conv_out + outputs

            outputs = F.relu(outputs)
            outputs = outputs.squeeze(3).transpose(1, 2)  # batch, seq_len, 2*hidden
            A, B = outputs.split(self.hidden_size, 2)     # A, B: batch, seq_len, hidden
            A2 = A.contiguous().view(-1, A.size(2))       # A2: batch * seq_len, hidden
            B2 = B.contiguous().view(-1, B.size(2))       # B2: batch * seq_len, hidden
            dec_attn = torch.mul(A2, self.softmax(B2))    # dec_attn: batch * seq_len, hidden

            dec_attn2 = self.mapping(dec_attn)            # dec_attn2: batch * seq_len, 2 * hidden
            dec_attn2 = dec_attn2.view(A.size(0), A.size(1), -1)

            enc_attn = enc_attn.view(A.size(0), -1, A.size(2))  # enc_attn: batch, seq_len_src, hidden_size
            dec_attn = dec_attn.view(A.size(0), -1, A.size(2))  # dec_attn: batch, seq_len_tgt, hidden_size

            # dot-product attention between decoder and encoder states
            _attn_matrix = torch.bmm(dec_attn, enc_attn.transpose(1, 2))  # batch, seq_len_tgt, seq_len_src
            attn_matrix = self.softmax(_attn_matrix.view(-1, _attn_matrix.size(2)))
            attn_matrix = attn_matrix.view(_attn_matrix.size(0), _attn_matrix.size(1), -1)  # normalized over source positions

            attns = torch.bmm(attn_matrix, source_seq_out)  # attns: batch, seq_len_tgt, 2 * hidden_size
            outputs = dec_attn2 + attns                     # outputs: batch, seq_len_tgt - 1, 2 * hidden_size
        return outputs

    def load_pretrained_vectors(self, opt):
        if opt.pre_word_vecs_dec is not None:
            pretrained = torch.load(opt.pre_word_vecs_dec)
            self.embedding.weight.data.copy_(pretrained)


class NMTModel(nn.Module):
    """
    Ties the convolutional encoder and decoder together.
    Args:
        source: ((seq_len_src, batch) LongTensor, lengths)
        target: (seq_len_tgt, batch) LongTensor
    Returns:
        out: batch, seq_len_tgt - 1, 2 * hidden_size
    """
    def __init__(self, encoder, decoder):
        super(NMTModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target):
        # attn: batch * seq_len_src, hidden; source_seq_out: batch, seq_len_src, 2 * hidden_size
        attn, source_seq_out = self.encoder(source)

        # batch, seq_len_tgt - 1, 2 * hidden_size
        out = self.decoder(source, target, attn, source_seq_out)

        return out
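
# Shape walkthrough (illustrative, with hypothetical sizes): for a batch of
# 64 sentences, hidden_size 512, kernel_size 5, source length 20 and target
# length 15:
#
#   source[0]: 20 x 64 LongTensor      target: 15 x 64 LongTensor
#   attn, source_seq_out = encoder(source)
#       attn:           (64 * 20) x 512
#       source_seq_out: 64 x 20 x 1024
#   out = decoder(source, target, attn, source_seq_out)
#       out:            64 x 14 x 1024   # one step shorter than the target
#
# train.py then applies the generator (Linear(1024 -> vocab) + LogSoftmax)
# to every position of `out`.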
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import argparse
import itertools
import os
import re
from collections import Counter

import torch

import Constants

parser = argparse.ArgumentParser(description='preprocess.py')

##
## **Preprocess Options**
##

parser.add_argument('-config', help="Read options from this file")


parser.add_argument('-files', type=str, default="/home/zeng/conversation/OpenNMT-py/data/test/",
                    help="Directory containing train/valid/test .src and .tgt files")
parser.add_argument('-source_train_file', type=str, default="/home/zeng/data/OpenSubData/train.src",
                    help="Path to the training source data")
parser.add_argument('-target_train_file', type=str, default="/home/zeng/data/OpenSubData/train.tgt",
                    help="Path to the training target data")
parser.add_argument('-source_valid_file', type=str, default="/home/zeng/data/OpenSubData/valid.src",
                    help="Path to the validation source data")
parser.add_argument('-target_valid_file', type=str, default="/home/zeng/data/OpenSubData/valid.tgt",
                    help="Path to the validation target data")
parser.add_argument('-source_test_file', type=str, default="/home/zeng/data/OpenSubData/test.src",
                    help="Path to the test source data")
parser.add_argument('-target_test_file', type=str, default="/home/zeng/data/OpenSubData/test.tgt",
                    help="Path to the test target data")

parser.add_argument('-save_data', type=str, default="/home/zeng/data/OpenSubData/5m",
                    help="Output file prefix for the prepared data")

parser.add_argument('-maximum_vocab_size', type=int, default=50000,
                    help="Maximum size of the source/target vocabularies")

parser.add_argument('-vocab',
                    help="Path to an existing vocabulary")

parser.add_argument('-seq_length', type=int, default=50,
                    help="Maximum sequence length")
parser.add_argument('-shuffle', type=int, default=1,
                    help="Shuffle data")
parser.add_argument('-seed', type=int, default=3435,
                    help="Random seed")

parser.add_argument('-lower', action='store_true', help='lowercase data')

parser.add_argument('-report_every', type=int, default=1000,
                    help="Report status every this many sentences")

opt = parser.parse_args()

torch.manual_seed(opt.seed)


def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    Note: kept for reference; load_source_and_target below assumes
    pre-tokenized text and does not call it.
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
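
# Example of what clean_str does (illustrative, not part of the original
# file) -- contractions are split off and punctuation is spaced out:
#
#   clean_str("I don't know?")  ->  "i do n't know ?"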
", string) 80 | string = re.sub(r"\s{2,}", " ", string) 81 | return string.strip().lower() 82 | 83 | 84 | def build_vocab(sequence, maximum_vocab_size=50000): 85 | word_count = Counter(itertools.chain(*sequence)).most_common(maximum_vocab_size) 86 | word2count = dict([(word[0], word[1]) for word in word_count]) 87 | 88 | word2index = dict([(word, index + 4) for index, word in enumerate(word2count) if word != "UNknown"]) 89 | word2index[Constants.PAD_WORD], word2index[Constants.BOS_WORD], word2index[Constants.EOS_WORD], word2index[ 90 | Constants.UNK_WORD] = \ 91 | Constants.PAD, Constants.BOS, Constants.EOS, Constants.UNK 92 | 93 | index2word = dict([(index + 4, word) for index, word in enumerate(word2count) if word != "UNknown"]) 94 | index2word[Constants.PAD], index2word[Constants.BOS], index2word[Constants.EOS], index2word[ 95 | Constants.UNK] = Constants.PAD_WORD, \ 96 | Constants.BOS_WORD, Constants.EOS_WORD, Constants.UNK_WORD 97 | 98 | # word2index[Constants.PAD_WORD], word2index[Constants.BOS_WORD], word2index[Constants.EOS_WORD], word2index[Constants.UNK_WORD] = \ 99 | # Constants.PAD, Constants.BOS, Constants.EOS, Constants.UNK 100 | 101 | index2word[Constants.PAD], index2word[Constants.BOS], index2word[Constants.EOS], index2word[Constants.UNK] = \ 102 | Constants.PAD_WORD, Constants.BOS_WORD, Constants.EOS_WORD, Constants.UNK_WORD 103 | return word2count, word2index, index2word 104 | 105 | 106 | def makeData(sources, targets, src_word2index, tgt_word2index, shuffle=opt.shuffle): 107 | assert len(sources) == len(targets) 108 | sizes = [] 109 | for idx in range(len(sources)): 110 | # Insert `eosWord` at the end 111 | src_words = [src_word2index[word] if word in src_word2index else Constants.UNK for word in sources[idx]] + [ 112 | Constants.EOS] 113 | sources[idx] = torch.LongTensor(src_words) 114 | 115 | sizes += [len(sources)] 116 | 117 | tgt_words = [Constants.BOS] + [tgt_word2index[word] if word in tgt_word2index else Constants.UNK for word in targets[idx]] + [ 118 | Constants.EOS] 119 | targets[idx] = torch.LongTensor(tgt_words) 120 | 121 | if shuffle == 1: 122 | print "... shuffling sentences" 123 | perm = torch.randperm(len(sources)) 124 | sources = [sources[idx] for idx in perm] 125 | targets = [targets[idx] for idx in perm] 126 | sizes = [sizes[idx] for idx in perm] 127 | 128 | print "... 
sorting sentences" 129 | _, perm = torch.sort(torch.Tensor(sizes)) 130 | sources = [sources[idx] for idx in perm] 131 | targets = [targets[idx] for idx in perm] 132 | 133 | return sources, targets 134 | 135 | 136 | def load_source_and_target(source_file, target_file): 137 | """ 138 | Source_file 139 | Target_file 140 | """ 141 | 142 | src_lines = open(source_file, "r").readlines() 143 | tgt_lines = open(target_file, "r").readlines() 144 | 145 | sources = [] 146 | targets = [] 147 | 148 | for src, tgt in zip(src_lines, tgt_lines): 149 | src = src.strip().split() 150 | tgt = tgt.strip().split() 151 | 152 | sources.append(src) 153 | targets.append(tgt) 154 | 155 | return sources, targets 156 | 157 | 158 | def main(): 159 | 160 | # train 161 | source_train_file = os.path.join(opt.files, "train.src") 162 | target_train_file = os.path.join(opt.files, "train.tgt") 163 | 164 | # valid 165 | source_valid_file = os.path.join(opt.files, "valid.src") 166 | target_valid_file = os.path.join(opt.files, "valid.tgt") 167 | 168 | # test 169 | source_test_file = os.path.join(opt.files, "test.src") 170 | target_test_file = os.path.join(opt.files, "test.tgt") 171 | 172 | source_train, target_train = load_source_and_target(source_train_file, target_train_file) 173 | source_valid, target_valid = load_source_and_target(source_valid_file, target_valid_file) 174 | source_test, target_test = load_source_and_target(source_test_file, target_test_file) 175 | 176 | source_texts = source_train + source_valid + source_test 177 | target_texts = target_train + target_valid + target_test 178 | 179 | src_word2count, src_word2index, src_index2word = build_vocab(source_texts, opt.maximum_vocab_size) 180 | tgt_word2count, tgt_word2index, tgt_index2word = build_vocab(target_texts, opt.maximum_vocab_size) 181 | 182 | dicts = {} 183 | word2index = {} 184 | word2index["src"] = src_word2index 185 | word2index["tgt"] = tgt_word2index 186 | index2word = {} 187 | index2word["src"] = src_index2word 188 | index2word["tgt"] = tgt_index2word 189 | dicts["word2index"] = word2index 190 | dicts["index2word"] = index2word 191 | 192 | 193 | 194 | print('Preparing training ...') 195 | train = {} 196 | train['src'], train['tgt'] = makeData(source_train, target_train, src_word2index, tgt_word2index) 197 | 198 | print('Preparing validation ...') 199 | valid = {} 200 | valid['src'], valid['tgt'] = makeData(source_valid, target_valid, src_word2index, tgt_word2index) 201 | 202 | print('Preparing testing ...') 203 | valid = {} 204 | valid['src'], valid['tgt'] = makeData(source_test, target_test, src_word2index, tgt_word2index) 205 | 206 | print "saving data to \'" + opt.save_data + ".train.pt\'..." 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import division

import argparse
import torch
import torch.nn as nn
from torch import cuda
from torch.autograd import Variable
import math
import time

import Model
from dataset import Dataset
from Optim import Optim
import Constants

parser = argparse.ArgumentParser(description='train.py')

## Data options

parser.add_argument('-data', required=True,
                    help='Path to the *.train.pt file from preprocess.py')
parser.add_argument('-save_model', default='model',
                    help="""Model filename prefix (each checkpoint is saved as
                    <save_model>_acc_ACC_ppl_PPL_eEPOCH.pt, where PPL is the
                    validation perplexity)""")
parser.add_argument('-train_from_state_dict', default='', type=str,
                    help="""If training from a checkpoint then this is the
                    path to the pretrained model's state_dict.""")
parser.add_argument('-train_from', default='', type=str,
                    help="""If training from a checkpoint then this is the
                    path to the pretrained model.""")

## Model options

parser.add_argument('-layers', type=int, default=2,
                    help='Number of layers in the LSTM encoder/decoder')
parser.add_argument('-rnn_size', type=int, default=512,
                    help='Size of LSTM hidden states')
parser.add_argument('-embedding_size', type=int, default=512,
                    help='Word embedding sizes')
parser.add_argument('-input_feed', type=int, default=1,
                    help="""Feed the context vector at each time step as
                    additional input (via concatenation with the word
                    embeddings) to the decoder.""")
parser.add_argument('-brnn', action='store_true',
                    help='Use a bidirectional encoder')
parser.add_argument('-brnn_merge', default='concat',
                    help="""Merge action for the bidirectional hidden states:
                    [concat|sum]""")


# CNN parameters
## Encoder or Decoder
parser.add_argument("-hidden_size", type=int, default=512,
                    help="CNN hidden size")
parser.add_argument("-kernel_size", type=int, default=5,
                    help="Convolution kernel size along the time axis")
parser.add_argument("-enc_layers", type=int, default=2,
                    help="Number of encoder hidden layers")

# Decoder
parser.add_argument("-dec_layers", type=int, default=2,
                    help="Number of decoder hidden layers")


## Optimization options

parser.add_argument('-batch_size', type=int, default=64,
                    help='Maximum batch size')
parser.add_argument('-max_generator_batches', type=int, default=32,
                    help="""Maximum number of time steps to run the generator
                    on in parallel. Higher is faster, but uses more memory.""")
parser.add_argument('-epochs', type=int, default=13,
                    help='Number of training epochs')
parser.add_argument('-start_epoch', type=int, default=1,
                    help='The epoch from which to start')
parser.add_argument('-param_init', type=float, default=0.1,
                    help="""Parameters are initialized over uniform distribution
                    with support (-param_init, param_init)""")
parser.add_argument('-optim', default='adam',
                    help="Optimization method. [sgd|adagrad|adadelta|adam]")
parser.add_argument('-max_grad_norm', type=float, default=5,
                    help="""If the norm of the gradient vector exceeds this,
                    renormalize it to have the norm equal to max_grad_norm""")
parser.add_argument('-dropout', type=float, default=0.3,
                    help='Dropout probability; applied between LSTM stacks.')
parser.add_argument('-curriculum', action="store_true",
                    help="""For this many epochs, order the minibatches based
                    on source sequence length. Sometimes setting this to 1 will
                    increase convergence speed.""")
parser.add_argument('-extra_shuffle', action="store_true",
                    help="""By default only shuffle mini-batch order; when true,
                    shuffle and re-assign mini-batches""")

# learning rate
parser.add_argument('-learning_rate', type=float, default=0.001,
                    help="""Starting learning rate. If adagrad/adadelta/adam is
                    used, then this is the global learning rate. Recommended
                    settings: sgd = 1, adagrad = 0.1, adadelta = 1, adam = 0.001""")
parser.add_argument('-learning_rate_decay', type=float, default=0.5,
                    help="""Decay learning rate by this much if (i) perplexity
                    does not decrease on the validation set or (ii) epoch has
                    gone past start_decay_at""")
parser.add_argument('-start_decay_at', type=int, default=8,
                    help="""Start decaying every epoch after and including this
                    epoch""")

# pretrained word vectors

parser.add_argument('-pre_word_vecs_enc',
                    help="""If a valid path is specified, then this will load
                    pretrained word embeddings on the encoder side.
                    See README for specific formatting instructions.""")
parser.add_argument('-pre_word_vecs_dec',
                    help="""If a valid path is specified, then this will load
                    pretrained word embeddings on the decoder side.
                    See README for specific formatting instructions.""")

# GPU
parser.add_argument('-gpus', default=[], nargs='+', type=int,
                    help="Use CUDA on the listed devices.")

parser.add_argument('-log_interval', type=int, default=1,
                    help="Print stats at this interval.")

opt = parser.parse_args()

print(opt)

if torch.cuda.is_available() and not opt.gpus:
    print("WARNING: You have a CUDA device, so you should probably run with -gpus 0")

if opt.gpus:
    cuda.set_device(opt.gpus[0])


def NMTCriterion(vocabSize):
    # NLL loss that ignores padding positions
    weight = torch.ones(vocabSize)
    weight[Constants.PAD] = 0
    crit = nn.NLLLoss(weight, size_average=False)
    if opt.gpus:
        crit.cuda()
    return crit


def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
    # compute generations one piece at a time
    num_correct, loss = 0, 0
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)

    targets = targets[1:]  # exclude the <s> (BOS) token at the beginning

    batch_size = outputs.size(0)  # outputs: batch, seq_len_tgt - 1, 2 * hidden_size

    outputs_split = torch.split(outputs.transpose(0, 1).contiguous(), opt.max_generator_batches)
    targets_split = torch.split(targets, opt.max_generator_batches)

    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
        out_t = out_t.view(-1, out_t.size(2))
        scores_t = generator(out_t)
        targ_t = targ_t.view(-1)
        loss_t = crit(scores_t, targ_t)
        pred_t = scores_t.max(1)[1]
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]
        if not eval:
            loss_t.div(batch_size).backward()

    grad_output = None if outputs.grad is None else outputs.grad.data
    return loss, grad_output, num_correct
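
# Why the two-stage backward above (a note on the existing code, not new
# behaviour): `outputs` is re-wrapped as a leaf Variable, so the generator and
# criterion only ever run on slices of max_generator_batches time steps, and
# each slice's graph is freed right after its backward(). The gradient
# accumulated in outputs.grad is then pushed through the encoder/decoder with
# a single outputs.backward(gradOutput) call in trainEpoch below, keeping
# peak memory bounded by the slice size rather than the full sequence.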


def eval(model, criterion, data):
    total_loss = 0
    total_words = 0
    total_num_correct = 0

    model.eval()
    for i in range(len(data)):
        batch = data[i]
        targets = batch[1]
        outputs = model(batch[0], targets)
        loss, _, num_correct = memoryEfficientLoss(
            outputs, targets, model.generator, criterion, eval=True)
        total_loss += loss
        total_num_correct += num_correct
        total_words += targets.data.ne(Constants.PAD).sum()

    model.train()
    return total_loss / total_words, total_num_correct / total_words


def trainModel(model, trainData, validData, dataset, optim, criterion):
    print(model)
    model.train()

    start_time = time.time()

    def trainEpoch(epoch):

        if opt.extra_shuffle and epoch > opt.curriculum:
            trainData.shuffle()

        # shuffle mini batch order
        batchOrder = torch.randperm(len(trainData))

        total_loss, total_words, total_num_correct = 0, 0, 0
        report_loss, report_tgt_words, report_src_words, report_num_correct = 0, 0, 0, 0
        start = time.time()
        for i in range(len(trainData)):

            batchIdx = batchOrder[i] if epoch > opt.curriculum else i
            batch = trainData[batchIdx]
            model.zero_grad()

            targets = batch[1]
            outputs = model(batch[0], targets)
            loss, gradOutput, num_correct = memoryEfficientLoss(
                outputs, targets, model.generator, criterion)

            outputs.backward(gradOutput)

            # update the parameters
            optim.step()

            num_words = targets.data.ne(Constants.PAD).sum()
            report_loss += loss
            report_num_correct += num_correct
            report_tgt_words += num_words
            report_src_words += sum(batch[0][1])
            total_loss += loss
            total_num_correct += num_correct
            total_words += num_words
            if i % opt.log_interval == -1 % opt.log_interval:
                print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
                      (epoch, i+1, len(trainData),
                       report_num_correct / report_tgt_words * 100,
                       math.exp(report_loss / report_tgt_words),
                       report_src_words / (time.time() - start),
                       report_tgt_words / (time.time() - start),
                       time.time() - start_time))

                report_loss = report_tgt_words = report_src_words = report_num_correct = 0
                start = time.time()

        return total_loss / total_words, total_num_correct / total_words

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # (1) train for one epoch on the training set
        train_loss, train_acc = trainEpoch(epoch)
        train_ppl = math.exp(min(train_loss, 100))
        print('Train perplexity: %g' % train_ppl)
        print('Train accuracy: %g' % (train_acc*100))

        # (2) evaluate on the validation set
        valid_loss, valid_acc = eval(model, criterion, validData)
        valid_ppl = math.exp(min(valid_loss, 100))
        print('Validation perplexity: %g' % valid_ppl)
        print('Validation accuracy: %g' % (valid_acc*100))

        # (3) update the learning rate
        optim.updateLearningRate(valid_loss, epoch)

        model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
        model_state_dict = {k: v for k, v in model_state_dict.items() if 'generator' not in k}
        generator_state_dict = model.generator.module.state_dict() if len(opt.gpus) > 1 else model.generator.state_dict()
        # (4) drop a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'generator': generator_state_dict,
            'dicts': dataset['dicts'],
            'opt': opt,
            'epoch': epoch,
            'optim': optim
        }
        torch.save(checkpoint,
                   '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (opt.save_model, 100*valid_acc, valid_ppl, epoch))
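
# Checkpoint contents (saved once per epoch above): the 'model' and
# 'generator' state dicts, the preprocessing 'dicts', the full 'opt'
# namespace, the 'epoch' number, and the Optim wrapper. Training resumes from
# such a file with -train_from_state_dict (for the state-dict checkpoints
# written here) or -train_from (for checkpoints that store a whole model
# object); see main() below.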


def main():

    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)

    dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict
    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint)
        dataset['dicts'] = checkpoint['dicts']

    trainData = Dataset(dataset['train']['src'],
                        dataset['train']['tgt'], opt.batch_size, opt.gpus)
    validData = Dataset(dataset['valid']['src'],
                        dataset['valid']['tgt'], opt.batch_size, opt.gpus,
                        volatile=True)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (len(dicts["word2index"]['src']), len(dicts["word2index"]['tgt'])))
    print(' * number of training sentences. %d' %
          len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    encoder = Model.Encoder(opt, len(dicts["word2index"]['src']))
    decoder = Model.Decoder(opt, len(dicts["word2index"]['tgt']))

    generator = nn.Sequential(
        nn.Linear(opt.hidden_size * 2, len(dicts["word2index"]['tgt'])),
        nn.LogSoftmax())

    model = Model.NMTModel(encoder, decoder)

    if opt.train_from:
        print('Loading model from checkpoint at %s' % opt.train_from)
        chk_model = checkpoint['model']
        generator_state_dict = chk_model.generator.state_dict()
        model_state_dict = {k: v for k, v in chk_model.state_dict().items() if 'generator' not in k}
        model.load_state_dict(model_state_dict)
        generator.load_state_dict(generator_state_dict)
        opt.start_epoch = checkpoint['epoch'] + 1

    if opt.train_from_state_dict:
        print('Loading model from checkpoint at %s' % opt.train_from_state_dict)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
    else:
        model.cpu()
        generator.cpu()

    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
        generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    model.generator = generator

    if not opt.train_from_state_dict and not opt.train_from:
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        encoder.load_pretrained_vectors(opt)
        decoder.load_pretrained_vectors(opt)

        optim = Optim(
            opt.optim, opt.learning_rate, opt.max_grad_norm,
            lr_decay=opt.learning_rate_decay,
            start_decay_at=opt.start_decay_at
        )
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        print(optim)

    optim.set_parameters(model.parameters())

    if opt.train_from or opt.train_from_state_dict:
        optim.optimizer.load_state_dict(checkpoint['optim'].optimizer.state_dict())

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    criterion = NMTCriterion(len(dicts["word2index"]['tgt']))

    trainModel(model, trainData, validData, dataset, optim, criterion)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------