├── README.md
├── data.py
├── model.py
├── module.py
├── test.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
[Deconvolutional Latent-Variable Model for Text Sequence Matching](https://arxiv.org/abs/1709.07109)

Unsupervised sequence learning (a deconvolutional text VAE) and semi-supervised learning on the SNLI dataset are implemented.
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe

from nltk import word_tokenize
import numpy as np


class SNLI():
    def __init__(self, args):
        # pad/truncate every sentence to 29 tokens so it matches the conv/deconv stack
        self.TEXT = data.Field(batch_first=True, tokenize=word_tokenize, lower=True, fix_length=29)
        self.LABEL = data.Field(sequential=False, unk_token=None)

        self.train, self.dev, self.test = datasets.SNLI.splits(self.TEXT, self.LABEL)

        self.TEXT.build_vocab(self.train, self.dev, self.test, vectors=GloVe(name='840B', dim=300))
        self.LABEL.build_vocab(self.train)

        self.train_iter, self.dev_iter, self.test_iter = \
            data.BucketIterator.splits((self.train, self.dev, self.test),
                                       batch_size=args.batch_size,
                                       device=args.gpu)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from module import *


class DeConvVAE(nn.Module):

    def __init__(self, args, data):
        super(DeConvVAE, self).__init__()
        self.args = args

        self.encoder = ConvolutionEncoder(args)

        self.fc_mu = nn.Linear(args.feature_maps[2], args.latent_size)
        self.fc_logvar = nn.Linear(args.feature_maps[2], args.latent_size)

        self.decoder = DeconvolutionDecoder(args)

        self.dropout = nn.Dropout(args.dropout)


    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu


    def forward(self, x, word_emb):
        # Encode
        h = self.encoder(self.dropout(x))
        mu = self.fc_mu(self.dropout(h))
        logvar = self.fc_logvar(self.dropout(h))

        # Sample
        z = self.reparameterize(mu, logvar)

        # Decode
        x_hat = self.decoder(z)

        # normalize
        norm_x_hat = torch.norm(x_hat, 2, dim=2, keepdim=True)
        rec_x_hat = x_hat / norm_x_hat
        norm_w = torch.norm(word_emb.weight.data, 2, dim=1, keepdim=True)
        rec_w = (word_emb.weight.data / (norm_w + 1e-20)).t()

        # compute probability
        prob_logits = torch.bmm(rec_x_hat, rec_w.unsqueeze(0)
                                .expand(rec_x_hat.size(0), *rec_w.size())) / self.args.tau
        log_prob = F.log_softmax(prob_logits, dim=2)

        return log_prob, mu, logvar, z


    def generate(self, sample_num, word_emb):
        latent_size = self.args.latent_size
        device = torch.device(self.args.device)

        # Sample from the standard normal prior
        z = torch.randn(sample_num, latent_size).to(device)

        # Decode
        x_hat = self.decoder(z)

        # normalize
        norm_x_hat = torch.norm(x_hat, 2, dim=2, keepdim=True)
        rec_x_hat = x_hat / norm_x_hat
        norm_w = torch.norm(word_emb.weight.data, 2, dim=1, keepdim=True)
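        # Both the decoded vectors and the embedding matrix are L2-normalized, so the
        # batched matmul below yields the cosine similarity between every decoded
        # position and every vocabulary word; dividing by the temperature tau before
        # log_softmax turns these into per-position word log-probabilities (the same
        # scheme used in forward() above).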
        rec_w = (word_emb.weight.data / (norm_w + 1e-20)).t()

        # compute probability
        prob_logits = torch.bmm(rec_x_hat, rec_w.unsqueeze(0)
                                .expand(rec_x_hat.size(0), *rec_w.size())) / self.args.tau
        log_prob = F.log_softmax(prob_logits, dim=2)

        return log_prob


class NN4VAE(nn.Module):

    def __init__(self, args, data):
        super(NN4VAE, self).__init__()

        self.args = args

        self.word_emb = nn.Embedding(args.word_vocab_size, args.word_dim)
        # initialize word embedding with GloVe
        self.word_emb.weight.data.copy_(data.TEXT.vocab.vectors)
        # fine-tune the word embedding
        self.word_emb.weight.requires_grad = True
        # the <unk> vector (index 0) is randomly initialized
        nn.init.uniform_(self.word_emb.weight.data[0], -0.05, 0.05)

        self.vae = DeConvVAE(args, data)


    def forward(self, x):
        # word embedding
        x = self.word_emb(x)

        log_prob, mu, logvar, z = self.vae(x, self.word_emb)

        return log_prob, mu, logvar, z


    def generate(self, sample_num):
        return self.vae.generate(sample_num, self.word_emb)


class NN4SNLI(nn.Module):

    def __init__(self, args, data):
        super(NN4SNLI, self).__init__()

        self.args = args

        self.word_emb = nn.Embedding(args.word_vocab_size, args.word_dim)
        # initialize word embedding with GloVe
        self.word_emb.weight.data.copy_(data.TEXT.vocab.vectors)
        # fine-tune the word embedding
        self.word_emb.weight.requires_grad = True
        # the <unk> vector (index 0) is randomly initialized
        nn.init.uniform_(self.word_emb.weight.data[0], -0.05, 0.05)

        self.vae = DeConvVAE(args, data)

        self.fc_1 = nn.Linear(4 * args.latent_size, args.hidden_size)
        self.fc_2 = nn.Linear(args.hidden_size, args.hidden_size)
        self.fc_out = nn.Linear(args.hidden_size, args.class_size)

        self.relu = nn.ReLU()


    def forward(self, batch):
        p = batch.premise
        h = batch.hypothesis

        # (batch, seq_len, word_dim)
        p_x = self.word_emb(p)
        h_x = self.word_emb(h)

        # VAE
        p_log_prob, p_mu, p_logvar, z_p = self.vae(p_x, self.word_emb)
        h_log_prob, h_mu, h_logvar, z_h = self.vae(h_x, self.word_emb)

        # matching layer
        m = torch.cat([z_p, z_h, z_p - z_h, z_p * z_h], dim=-1)

        # fully-connected layers
        out = self.relu(self.fc_1(m))
        out = self.relu(self.fc_2(out))
        out = self.fc_out(out)

        return out, p_log_prob, p_mu, p_logvar, h_log_prob, h_mu, h_logvar
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ConvolutionEncoder(nn.Module):

    def __init__(self, args):
        super(ConvolutionEncoder, self).__init__()

        self.conv1 = nn.Conv2d(1, args.feature_maps[0], (args.filter_size, args.word_dim), stride=args.stride)
        self.conv2 = nn.Conv2d(args.feature_maps[0], args.feature_maps[1], (args.filter_size, 1), stride=args.stride)
        self.conv3 = nn.Conv2d(args.feature_maps[1], args.feature_maps[2], (args.filter_size, 1), stride=args.stride)

        self.relu = nn.ReLU()


    def forward(self, x):
        # reshape for convolution layer: (batch, 1, seq_len, word_dim)
        x.unsqueeze_(1)

        h1 = self.relu(self.conv1(x))
        h2 = self.relu(self.conv2(h1))
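        # With the default settings (fix_length=29, word_dim=300, filter_size=5,
        # stride=2, feature_maps=[300, 600, 500]) the shapes are
        # (batch, 1, 29, 300) -> conv1 -> (batch, 300, 13, 1) -> conv2 -> (batch, 600, 5, 1)
        # -> conv3 -> (batch, 500, 1, 1), i.e. conv3 collapses the whole sequence
        # into a single feature vector per sentence.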
        h3 = self.relu(self.conv3(h2))

        # (batch, feature_maps[2])
        h3.squeeze_()
        if len(h3.size()) < 2:
            h3.unsqueeze_(0)
        return h3


class DeconvolutionDecoder(nn.Module):

    def __init__(self, args):
        super(DeconvolutionDecoder, self).__init__()

        self.deconv1 = nn.ConvTranspose2d(args.latent_size, args.feature_maps[1], (args.filter_size, 1), stride=args.stride)
        self.deconv2 = nn.ConvTranspose2d(args.feature_maps[1], args.feature_maps[0], (args.filter_size, 1), stride=args.stride)
        self.deconv3 = nn.ConvTranspose2d(args.feature_maps[0], 1, (args.filter_size, args.word_dim), stride=args.stride)

        self.relu = nn.ReLU()


    def forward(self, z):
        # reshape for deconvolution layer: (batch, latent_size, 1, 1)
        z = z.unsqueeze(-1).unsqueeze(-1)

        # mirror of the encoder: the transposed convolutions expand the single latent
        # position back to the full sequence, (batch, 500, 1, 1) -> (batch, 600, 5, 1)
        # -> (batch, 300, 13, 1) -> (batch, 1, 29, 300) with the default settings
        h2 = self.relu(self.deconv1(z))
        h1 = self.relu(self.deconv2(h2))
        x_hat = self.relu(self.deconv3(h1))

        # (batch, seq_len, word_dim)
        x_hat.squeeze_()
        if len(x_hat.size()) < 3:
            x_hat.unsqueeze_(0)
        return x_hat
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F

from random import randint
from types import SimpleNamespace


def compute_cross_entropy(log_prob, target):
    # compute reconstruction loss using cross entropy:
    # log_prob is (batch, seq_len, vocab); each sentence's NLL is summed over
    # positions, then averaged over the batch
    loss = [F.nll_loss(sentence_emb_matrix, word_ids, reduction='sum') for sentence_emb_matrix, word_ids in zip(log_prob, target)]
    average_loss = sum([torch.sum(l) for l in loss]) / log_prob.size()[0]
    return average_loss


def loss_function(log_prob, target, mu, logvar):
    reconst = compute_cross_entropy(log_prob, target)
    # closed-form KL divergence between N(mu, sigma^2) and the standard normal prior,
    # summed over latent dimensions and averaged over the batch
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()

    return reconst, KLD


def test(model, data, mode='test'):
    with torch.no_grad():
        if mode == 'dev':
            iterator = iter(data.dev_iter)
        else:
            iterator = iter(data.test_iter)

        model.eval()
        test_reconst, test_KLD = 0, 0
        loss, size = 0, 0

        for batch in iterator:
            batch_text = torch.cat([batch.premise, batch.hypothesis], dim=0)
            log_prob, mu, logvar, _ = model(batch_text)

            reconst, KLD = loss_function(log_prob, batch_text, mu, logvar)
            batch_loss = reconst + KLD
            loss += batch_loss.item()

            test_reconst += reconst.item()
            test_KLD += KLD.item()
            size += 1

        test_reconst /= size
        test_KLD /= size
        loss /= size
        return loss, test_reconst, test_KLD


def snli_test(model, data, mode='test'):
    with torch.no_grad():
        if mode == 'dev':
            iterator = iter(data.dev_iter)
        else:
            iterator = iter(data.test_iter)

        criterion = nn.CrossEntropyLoss()
        model.eval()
        acc, loss, size = 0, 0, 0

        for batch in iterator:
            pred, _, _, _, _, _, _ = model(batch)

            batch_loss = criterion(pred, batch.label)
            loss += batch_loss.item()

            _, pred = pred.max(dim=1)
            acc += (pred == batch.label).sum().float()
            size += len(pred)

        acc /= size
        acc = acc.cpu().item()
        return loss, acc


# reconstruct a random example from the test set
def example(model, args, data):
    i = randint(0, len(data.test.examples) - 1)

    e = data.test.examples[i]

    print(e.premise)
    p = torch.ones(29, dtype=torch.long).to(torch.device(args.device))
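    # index 1 is the <pad> token in torchtext's default vocabulary ordering
    # (index 0 is <unk>), so a vector of ones is an all-padding sequence of the
    # fixed length 29; the loop below fills in the premise's word ids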
    for i in range(len(e.premise)):
        if i < 29:
            p[i] = data.TEXT.vocab.stoi[e.premise[i]]

    model.eval()
    log_prob, mu, logvar, _ = model(p.unsqueeze(0))

    _, predict_index = torch.max(log_prob, 2)
    p_predict = [data.TEXT.vocab.itos[word] for word in predict_index[0]]

    print(p_predict)


def snli_example(model, args, data):
    i = randint(0, len(data.test.examples) - 1)

    e = data.test.examples[i]

    print(e.premise)
    print(e.hypothesis)
    p = torch.ones(29, dtype=torch.long).to(torch.device(args.device))
    h = torch.ones(29, dtype=torch.long).to(torch.device(args.device))
    for i in range(len(e.premise)):
        if i < 29:
            p[i] = data.TEXT.vocab.stoi[e.premise[i]]
    for i in range(len(e.hypothesis)):
        if i < 29:
            h[i] = data.TEXT.vocab.stoi[e.hypothesis[i]]

    # wrap the two tensors in a lightweight batch-like object
    example = SimpleNamespace(premise=p.unsqueeze(0), hypothesis=h.unsqueeze(0))

    model.eval()
    pred, p_log_prob, p_mu, p_logvar, h_log_prob, h_mu, h_logvar = model(example)

    _, p_predict_index = torch.max(p_log_prob, 2)
    _, h_predict_index = torch.max(h_log_prob, 2)
    p_predict = [data.TEXT.vocab.itos[word] for word in p_predict_index[0]]
    h_predict = [data.TEXT.vocab.itos[word] for word in h_predict_index[0]]

    print(p_predict)
    print(h_predict)


# greedily decode sample_num sentences sampled from the prior
def generate(model, args, data, sample_num):
    log_prob = model.generate(sample_num)
    _, predict_index = torch.max(log_prob, 2)

    for sentence in predict_index:
        predict = [data.TEXT.vocab.itos[word] for word in sentence]
        print(predict)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import copy
import os
import torch

from torch import nn, optim
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from time import gmtime, strftime

from model import NN4VAE, NN4SNLI
from data import SNLI
from test import test, snli_test, example, generate


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def compute_cross_entropy(log_prob, target):
    # compute reconstruction loss using cross entropy
    loss = [F.nll_loss(sentence_emb_matrix, word_ids, reduction='sum') for sentence_emb_matrix, word_ids in zip(log_prob, target)]
    average_loss = sum([torch.sum(l) for l in loss]) / log_prob.size()[0]
    return average_loss


def loss_function(log_prob, target, mu, logvar):
    reconst = compute_cross_entropy(log_prob, target)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()

    return reconst, KLD


def train(args, data):
    model = NN4VAE(args, data)
    model.to(torch.device(args.device))

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(parameters, lr=args.learning_rate)
    print("number of all parameters: " + str(count_parameters(model)))

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    train_reconst, train_KLD = 0, 0
    loss, size, last_epoch = 0, 0, -1

    iterator = data.train_iter
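    # the training BucketIterator is expected to cycle over the data indefinitely
    # (legacy torchtext behaviour), so the loop below watches iterator.epoch and
    # stops once args.epoch passes have been completed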
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
            generate(model, args, data, sample_num=10)
        last_epoch = present_epoch

        batch_text = torch.cat([batch.premise, batch.hypothesis], dim=0)
        log_prob, mu, logvar, _ = model(batch_text)

        optimizer.zero_grad()
        reconst, KLD = loss_function(log_prob, batch_text, mu, logvar)
        batch_loss = reconst + KLD
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        train_reconst += reconst.item()
        train_KLD += KLD.item()
        size += 1

        writer.add_scalar('KL_divergence/train', KLD.item(), size)
        if (i + 1) % args.print_freq == 0:
            train_reconst /= size
            train_KLD /= size
            loss /= size

            dev_loss, dev_reconst, dev_KLD = test(model, data, mode='dev')
            test_loss, test_reconst, test_KLD = test(model, data)

            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('reconstruction loss/train', train_reconst, c)

            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('reconstruction loss/dev', dev_reconst, c)
            writer.add_scalar('KL_divergence/dev', dev_KLD, c)
            writer.add_scalar('loss/test', test_loss, c)
            writer.add_scalar('reconstruction loss/test', test_reconst, c)
            writer.add_scalar('KL_divergence/test', test_KLD, c)

            print(f'train loss: {loss:.5f} / train reconstruction loss: {train_reconst:.5f} / train KL divergence: {train_KLD:.5f}')
            print(f'dev loss: {dev_loss:.5f} / dev reconstruction loss: {dev_reconst:.5f} / dev KL divergence: {dev_KLD:.5f}')
            print(f'test loss: {test_loss:.5f} / test reconstruction loss: {test_reconst:.5f} / test KL divergence: {test_KLD:.5f}')

            example(model, args, data)

            train_reconst, train_KLD, loss, size = 0, 0, 0, 0
            model.train()

    writer.close()

    return model


def snli_train(args, data):
    model = NN4SNLI(args, data)
    model.to(torch.device(args.device))

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()
    print("number of all parameters: " + str(count_parameters(model)))

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    acc, loss, size, last_epoch = 0, 0, 0, -1
    train_reconst, train_KLD, vae_size = 0, 0, 0
    max_dev_acc, max_test_acc = 0, 0
    alpha = torch.Tensor([-0.1]).to(torch.device(args.device))

    iterator = data.train_iter
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
            if alpha < 1:
                alpha += 0.1
        last_epoch = present_epoch

        pred, p_log_prob, p_mu, p_logvar, h_log_prob, h_mu, h_logvar = model(batch)

        optimizer.zero_grad()
        batch_loss = alpha * criterion(pred, batch.label)
        p_reconst, p_KLD = loss_function(p_log_prob, batch.premise, p_mu, p_logvar)
        h_reconst, h_KLD = loss_function(h_log_prob, batch.hypothesis, h_mu, h_logvar)
        batch_loss += p_reconst + h_reconst + p_KLD + h_KLD
        loss += batch_loss.item()
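        # backpropagate through the joint semi-supervised objective: the
        # alpha-weighted classification cross-entropy plus reconstruction and KL
        # terms for both premise and hypothesis, where alpha is annealed from 0
        # to 1 in steps of 0.1 per epoch (see the epoch check above)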
        batch_loss.backward()
        optimizer.step()

        train_reconst += p_reconst.item() + h_reconst.item()
        train_KLD += p_KLD.item() + h_KLD.item()
        vae_size += 2

        _, pred = pred.max(dim=1)
        acc += (pred == batch.label).sum().float()
        size += len(pred)

        if (i + 1) % args.print_freq == 0:
            acc /= size
            acc = acc.cpu().item()
            loss /= vae_size
            train_reconst /= vae_size
            train_KLD /= vae_size
            dev_loss, dev_acc = snli_test(model, data, mode='dev')
            test_loss, test_acc = snli_test(model, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('acc/train', acc, c)
            writer.add_scalar('reconstruction loss/train', train_reconst, c)
            writer.add_scalar('KL_divergence/train', train_KLD, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('acc/dev', dev_acc, c)
            writer.add_scalar('loss/test', test_loss, c)
            writer.add_scalar('acc/test', test_acc, c)

            print(f'train loss: {loss:.5f} / train reconstruction loss: {train_reconst:.5f} / train KL divergence: {train_KLD:.5f}')
            print(f'dev loss: {dev_loss:.3f} / test loss: {test_loss:.3f}'
                  f' / train acc: {acc:.3f} / dev acc: {dev_acc:.3f} / test acc: {test_acc:.3f}')

            if dev_acc > max_dev_acc:
                max_dev_acc = dev_acc
                max_test_acc = test_acc
                best_model = copy.deepcopy(model)

            acc, loss, size = 0, 0, 0
            vae_size = 0
            model.train()

    writer.close()
    print(f'max dev acc: {max_dev_acc:.3f} / max test acc: {max_test_acc:.3f}')

    return best_model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--data-type', default='SNLI')
    parser.add_argument('--dropout', default=0.3, type=float)
    parser.add_argument('--epoch', default=20, type=int)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--learning-rate', default=3e-4, type=float)
    parser.add_argument('--print-freq', default=3000, type=int)
    parser.add_argument('--word-dim', default=300, type=int)
    parser.add_argument('--filter-size', default=5, type=int)
    parser.add_argument('--stride', default=2, type=int)
    parser.add_argument('--latent-size', default=500, type=int)
    parser.add_argument('--tau', default=0.01, type=float)
    parser.add_argument('--hidden-size', default=500, type=int)
    parser.add_argument('--mode', default='VAE', help="available mode: VAE, SNLI")

    args = parser.parse_args()

    print('loading SNLI data...')
    data = SNLI(args)

    setattr(args, 'word_vocab_size', len(data.TEXT.vocab))
    setattr(args, 'class_size', len(data.LABEL.vocab))
    setattr(args, 'model_time', strftime('%H:%M:%S', gmtime()))
    setattr(args, 'feature_maps', [300, 600, 500])
    if args.gpu > -1:
        setattr(args, 'device', "cuda:0")
    else:
        setattr(args, 'device', "cpu")

    print('training start!')
    if args.mode == 'VAE':
        best_model = train(args, data)
    else:
        best_model = snli_train(args, data)

    if not os.path.exists('saved_models'):
        os.makedirs('saved_models')
    if args.mode == 'VAE':
        torch.save(best_model.state_dict(), f'saved_models/DeConv_VAE_{args.data_type}_{args.model_time}.pt')
    else:
        torch.save(best_model.state_dict(), f'saved_models/DeConv_VAE_SNLI_{args.data_type}_{args.model_time}.pt')
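    # only the weights are saved; reloading requires rebuilding the model with the
    # same hyperparameters and vocabulary and calling model.load_state_dict(torch.load(path))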

    print('training finished!')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
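Usage: `python train.py --mode VAE` trains the unsupervised reconstruction model and `python train.py --mode SNLI` runs semi-supervised SNLI matching; pass `--gpu -1` to stay on the CPU. The SNLI corpus and GloVe 840B vectors are downloaded by torchtext on first run, and NLTK's `punkt` data must be available for `word_tokenize`. The snippet below is a minimal sketch of reloading a saved VAE checkpoint and sampling sentences from the prior; the checkpoint filename is hypothetical, and the Namespace must carry the same hyperparameters that were used at training time.

import argparse
import torch

from data import SNLI
from model import NN4VAE
from test import generate

# hyperparameters matching the defaults in train.py (feature_maps is set there via setattr)
args = argparse.Namespace(
    batch_size=16, gpu=-1, dropout=0.3, word_dim=300, filter_size=5, stride=2,
    latent_size=500, tau=0.01, hidden_size=500, device='cpu',
    feature_maps=[300, 600, 500])

data = SNLI(args)                               # rebuilds the SNLI splits and GloVe vocabulary
args.word_vocab_size = len(data.TEXT.vocab)

model = NN4VAE(args, data)
# hypothetical checkpoint path; use the file written by train.py
model.load_state_dict(torch.load('saved_models/DeConv_VAE_SNLI_00:00:00.pt',
                                 map_location='cpu'))
model.eval()

generate(model, args, data, sample_num=10)      # prints 10 sentences decoded from the prior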